diff --git a/.secrets.baseline b/.secrets.baseline
index 4d88ce93..e388590e 100644
--- a/.secrets.baseline
+++ b/.secrets.baseline
@@ -139,6 +139,16 @@
     }
   ],
   "results": {
+    ".github/workflows/create_release.yml": [
+      {
+        "type": "Secret Keyword",
+        "filename": ".github/workflows/create_release.yml",
+        "hashed_secret": "3e26d6750975d678acb8fa35a0f69237881576b0",
+        "is_verified": false,
+        "line_number": 15,
+        "is_secret": false
+      }
+    ],
     "docs/workflow_mq.html": [
       {
         "type": "Base64 High Entropy String",
@@ -150,5 +160,5 @@
       }
     ]
   },
-  "generated_at": "2024-09-18T09:54:14Z"
+  "generated_at": "2024-10-08T15:30:41Z"
 }
diff --git a/HISTORY.md b/HISTORY.md
index 4eec0e5e..c4890279 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -40,7 +40,7 @@
 # 0.5.3
 * FIX FragPipe loading issue

-# 0.5.2 
+# 0.5.2
 * FIX FragPipe import #173

 # 0.5.1
diff --git a/README.md b/README.md
index 3a8be1f1..74d82f01 100644
--- a/README.md
+++ b/README.md
@@ -78,13 +78,20 @@ alphastats gui
 ```
 If you get an `AxiosError: Request failed with status code 403'` when uploading files, try running `DISABLE_XSRF=1 alphastats gui`.
-If you want to use local Large Language Models to help interpret the data,
-you need to download and install ollama (https://ollama.com/download). The url of the server can be set by the
-environmental variable `OLLAMA_BASE_URL` (defaults to `http://localhost:11434`)
-
 AlphaStats can be imported as a Python package into any Python script or notebook with the command `import alphastats`.
 A brief [Jupyter notebook tutorial](nbs/getting_started.ipynb) on how to use the API is also present in the [nbs folder](nbs).

+### LLM Support
+If you want to use local Large Language Models to help interpret the data,
+you need either an OpenAI API key (to use GPT-4o)
+or a server running Ollama.
+
+To provision Ollama, first download and install the runtime (https://ollama.com/download).
+Then, pull the recommended model: `ollama pull llama3.1:70b`.
+
+By default, Ollama models are served at `http://localhost:11434`, which is also the default for AlphaPeptStats.
+You can override the server URL via the environment variable `OLLAMA_BASE_URL`.
+
 ### One Click Installer
@@ -148,6 +155,13 @@ You can run the checks yourself using:
 pre-commit run --all-files
 ```

+##### The `detect-secrets` hook fails
+This happens when you add code that is flagged as a potential secret.
+1. Run `detect-secrets scan --exclude-files testfiles --exclude-lines '"(hash|id|image/\w+)":.*' > .secrets.baseline`
+   (check `.pre-commit-config.yaml` for the exact parameters)
+2. Run `detect-secrets audit .secrets.baseline` and check whether the detected 'secret' is actually a secret
+3. 
Commit the latest version of `.secrets.baseline` + --- diff --git a/alphastats/DataSet.py b/alphastats/DataSet.py index b51f8bdb..c0025b20 100644 --- a/alphastats/DataSet.py +++ b/alphastats/DataSet.py @@ -1,17 +1,16 @@ -import logging -import warnings from typing import Dict, List, Optional, Tuple, Union -import numpy as np import pandas as pd import plotly import scipy from alphastats import BaseLoader from alphastats.dataset_factory import DataSetFactory +from alphastats.dataset_harmonizer import DataHarmonizer from alphastats.DataSet_Plot import Plot from alphastats.DataSet_Preprocess import Preprocess from alphastats.DataSet_Statistics import Statistics +from alphastats.keys import Cols from alphastats.plots.ClusterMap import ClusterMap from alphastats.plots.DimensionalityReduction import DimensionalityReduction from alphastats.plots.IntensityPlot import IntensityPlot @@ -65,18 +64,22 @@ def __init__( """ self._check_loader(loader=loader) + self._data_harmonizer = DataHarmonizer( + loader, sample_column + ) # TODO should be moved to the loaders + # fill data from loader - self.rawinput: pd.DataFrame = loader.rawinput + self.rawinput: pd.DataFrame = self._data_harmonizer.get_harmonized_rawinput( + loader.rawinput + ) self.filter_columns: List[str] = loader.filter_columns - self.index_column: str = loader.index_column self.software: str = loader.software - self._gene_names: str = loader.gene_names - self._intensity_column: Union[str, list] = ( loader._extract_sample_names( - metadata=self.metadata, sample_column=self.sample + metadata=self.metadata, sample_column=sample_column ) - if loader == "Generic" + if loader + == "Generic" # TODO is this ever the case? not rather instanceof(loader, GenericLoader)? else loader.intensity_column ) @@ -84,28 +87,43 @@ def __init__( self._dataset_factory = DataSetFactory( rawinput=self.rawinput, - index_column=self.index_column, intensity_column=self._intensity_column, metadata_path_or_df=metadata_path_or_df, - sample_column=sample_column, + data_harmonizer=self._data_harmonizer, ) - rawmat, mat, metadata, sample, preprocessing_info = self._get_init_dataset() + rawmat, mat, metadata, preprocessing_info = self._get_init_dataset() self.rawmat: pd.DataFrame = rawmat self.mat: pd.DataFrame = mat self.metadata: pd.DataFrame = metadata - self.sample: str = sample self.preprocessing_info: Dict = preprocessing_info + self._gene_name_to_protein_id_map = ( + { + k: v + for k, v in dict( + zip( + self.rawinput[Cols.GENE_NAMES].tolist(), + self.rawinput[Cols.INDEX].tolist(), + ) + ).items() + if isinstance(k, str) # avoid having NaN as key + } + if Cols.GENE_NAMES in self.rawinput.columns + else {} + ) + # TODO This is not necessarily unique, and should ideally raise an error in some of our test-data sets that + # contain isoform ids. E.g. TPM1 occurs 5 times in testfiles/maxquant/proteinGroups.txt with different base Protein IDs. 
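(Aside: the gene-name lookup built above can be sketched in isolation. A minimal, runnable illustration follows; the `"Protein IDs"`/`"Gene names"` column names are stand-ins for the harmonized `Cols.INDEX`/`Cols.GENE_NAMES` keys, not the actual constants.)

```python
import pandas as pd

# Illustrative stand-ins for the harmonized Cols.INDEX / Cols.GENE_NAMES columns.
rawinput = pd.DataFrame(
    {
        "Protein IDs": ["P18206;A0A024QZN4", "Q5JQ13"],
        "Gene names": ["VCL;HEL114", float("nan")],
    }
)

# Build the gene-name -> protein-ID lookup, skipping NaN keys
# (missing gene names come through as float NaN, not str).
gene_name_to_protein_id = {
    gene: protein_id
    for gene, protein_id in zip(rawinput["Gene names"], rawinput["Protein IDs"])
    if isinstance(gene, str)
}

print(gene_name_to_protein_id)  # {'VCL;HEL114': 'P18206;A0A024QZN4'}
```

The `isinstance(gene, str)` guard mirrors the NaN check above: without it, every missing gene name would collide on a single NaN dictionary key.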
+ print("DataSet has been created.") def _get_init_dataset( self, - ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, str, Dict]: + ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, Dict]: """Get the initial data structure for the DataSet.""" rawmat, mat = self._dataset_factory.create_matrix_from_rawinput() - metadata, sample = self._dataset_factory.create_metadata(mat) + metadata = self._dataset_factory.create_metadata(mat) preprocessing_info = Preprocess.init_preprocessing_info( num_samples=mat.shape[0], @@ -114,9 +132,10 @@ def _get_init_dataset( filter_columns=self.filter_columns, ) - return rawmat, mat, metadata, sample, preprocessing_info + return rawmat, mat, metadata, preprocessing_info - def _check_loader(self, loader): + @staticmethod + def _check_loader(loader): """Checks if the Loader is from class AlphaPeptLoader, MaxQuantLoader, DIANNLoader, FragPipeLoader Args: @@ -143,8 +162,6 @@ def _get_preprocess(self) -> Preprocess: return Preprocess( self.filter_columns, self.rawinput, - self.index_column, - self.sample, self.metadata, self.preprocessing_info, self.mat, @@ -181,7 +198,6 @@ def reset_preprocessing(self): self.rawmat, self.mat, self.metadata, - self.sample, self.preprocessing_info, ) = self._get_init_dataset() @@ -194,8 +210,6 @@ def _get_statistics(self) -> Statistics: return Statistics( mat=self.mat, metadata=self.metadata, - index_column=self.index_column, - sample=self.sample, preprocessing_info=self.preprocessing_info, ) @@ -220,14 +234,13 @@ def diff_expression_analysis( def tukey_test(self, protein_id: str, group: str) -> pd.DataFrame: """A wrapper for tukey_test.tukey_test(), see documentation there.""" - df = self.mat[[protein_id]].reset_index().rename(columns={"index": self.sample}) - df = df.merge(self.metadata, how="inner", on=[self.sample]) + df = self.mat[[protein_id]].reset_index().rename(columns={"index": Cols.SAMPLE}) + df = df.merge(self.metadata, how="inner", on=[Cols.SAMPLE]) return tukey_test( df, protein_id, group, - self.index_column, ) def anova(self, column: str, protein_ids="all", tukey: bool = True) -> pd.DataFrame: @@ -240,6 +253,19 @@ def ancova( """A wrapper for Statistics.ancova(), see documentation there.""" return self._get_statistics().ancova(protein_id, covar, between) + def multicova_analysis( + self, + covariates: list, + n_permutations: int = 3, + fdr: float = 0.05, + s0: float = 0.05, + subset: dict = None, + ) -> Tuple[pd.DataFrame, list]: + """A wrapper for Statistics.multicova_analysis(), see documentation there.""" + return self._get_statistics().multicova_analysis( + covariates, n_permutations, fdr, s0, subset + ) + @check_for_missing_values def plot_pca(self, group: Optional[str] = None, circle: bool = False): """Plot Principal Component Analysis (PCA) @@ -254,7 +280,6 @@ def plot_pca(self, group: Optional[str] = None, circle: bool = False): dimensionality_reduction = DimensionalityReduction( mat=self.mat, metadata=self.metadata, - sample=self.sample, preprocessing_info=self.preprocessing_info, group=group, circle=circle, @@ -282,7 +307,6 @@ def plot_tsne( dimensionality_reduction = DimensionalityReduction( mat=self.mat, metadata=self.metadata, - sample=self.sample, preprocessing_info=self.preprocessing_info, group=group, method="tsne", @@ -306,7 +330,6 @@ def plot_umap(self, group: Optional[str] = None, circle: bool = False): dimensionality_reduction = DimensionalityReduction( mat=self.mat, metadata=self.metadata, - sample=self.sample, preprocessing_info=self.preprocessing_info, group=group, method="umap", @@ -314,6 
+337,27 @@ def plot_umap(self, group: Optional[str] = None, circle: bool = False): ) return dimensionality_reduction.plot + def perform_dimensionality_reduction( + self, method: str, group: Optional[str] = None, circle: bool = False + ): + """Generic wrapper for dimensionality reduction methods to be used by LLM. + + Args: + method (str): "pca", "tsne", "umap" + group (str, optional): column in metadata that should be used for coloring. Defaults to None. + circle (bool, optional): draw circle around each group. Defaults to False. + """ + + result = { + "pca": self.plot_pca, + "tsne": self.plot_tsne, + "umap": self.plot_umap, + }.get(method) + if result is None: + raise ValueError(f"Invalid method: {method}") + + return result(group=group, circle=circle) + @ignore_warning(RuntimeWarning) def plot_volcano( self, @@ -328,7 +372,7 @@ def plot_volcano( perm: int = 100, fdr: float = 0.05, # compare_preprocessing_modes: bool = False, # TODO reimplement - color_list: list = [], + color_list: list = None, ): """Plot Volcano Plot @@ -344,7 +388,7 @@ def plot_volcano( perm(float,optional): number of permutations when using SAM as method. Defaults to 100. fdr(float,optional): FDR cut off when using SAM as method. Defaults to 0.05. color_list (list): list with ProteinIDs that should be highlighted. - compare_preprocessing_modes(bool): Will iterate through normalization and imputation modes and return a list of VolcanoPlots in different settings, Default False. + #compare_preprocessing_modes(bool): Will iterate through normalization and imputation modes and return a list of VolcanoPlots in different settings, Default False. Returns: @@ -360,13 +404,12 @@ def plot_volcano( # return results # # else: + if color_list is None: + color_list = [] volcano_plot = VolcanoPlot( mat=self.mat, rawinput=self.rawinput, metadata=self.metadata, - sample=self.sample, - index_column=self.index_column, - gene_names=self._gene_names, preprocessing_info=self.preprocessing_info, group1=group1, group2=group2, @@ -383,12 +426,35 @@ def plot_volcano( return volcano_plot.plot + def _get_protein_id_for_gene_name( + self, + gene_name: str, + ) -> str: + """Get protein id from gene id. If gene id is not present, return gene id, as we might already have a gene id. + 'VCL;HEL114' -> 'P18206;A0A024QZN4;V9HWK2;B3KXA2;Q5JQ13;B4DKC9;B4DTM7;A0A096LPE1' + + Args: + gene_name (str): Gene name + + Returns: + str: Protein id or gene name if not present in the mapping. + """ + if gene_name in self._gene_name_to_protein_id_map: + return self._gene_name_to_protein_id_map[gene_name] + + for gene, protein_id in self._gene_name_to_protein_id_map.items(): + if gene_name in gene.split(";"): + return protein_id + return gene_name + def plot_intensity( self, - protein_id: str, + *, + protein_id: str = None, + gene_name: str = None, group: str = None, subgroups: list = None, - method: str = "box", + method: str = "box", # TODO rename add_significance: bool = False, log_scale: bool = False, # compare_preprocessing_modes: bool = False, TODO reimplement @@ -396,7 +462,8 @@ def plot_intensity( """Plot Intensity of individual Protein/ProteinGroup Args: - protein_id (str): ProteinGroup ID + protein_id (str): ProteinGroup ID. Mutually exclusive with gene_name. + gene_name (str): Gene Name, will be mapped to a ProteinGroup ID. Mutually exclusive with protein_id. group (str, optional): A metadata column used for grouping. Defaults to None. subgroups (list, optional): Select variables from the group column. Defaults to None. 
method (str, optional): Violinplot = "violin", Boxplot = "box", Scatterplot = "scatter" or "all". Defaults to "box". @@ -414,10 +481,18 @@ def plot_intensity( # ) # return results + if gene_name is None and protein_id is not None: + pass + elif gene_name is not None and protein_id is None: + protein_id = self._get_protein_id_for_gene_name(gene_name) + else: + raise ValueError( + "Either protein_id or gene_name must be provided, but not both." + ) + intensity_plot = IntensityPlot( mat=self.mat, metadata=self.metadata, - sample=self.sample, intensity_column=self._intensity_column, preprocessing_info=self.preprocessing_info, protein_id=protein_id, @@ -454,8 +529,6 @@ def plot_clustermap( clustermap = ClusterMap( mat=self.mat, metadata=self.metadata, - sample=self.sample, - index_column=self.index_column, preprocessing_info=self.preprocessing_info, label_bar=label_bar, only_significant=only_significant, @@ -478,7 +551,6 @@ def _get_plot(self) -> Plot: self.mat, self.rawmat, self.metadata, - self.sample, self.preprocessing_info, ) @@ -489,7 +561,7 @@ def plot_correlation_matrix(self, method: str = "pearson"): # TODO unused def plot_sampledistribution( self, method: str = "violin", - color: str = None, + color: str = None, # TODO rename to group log_scale: bool = False, use_raw: bool = False, ): diff --git a/alphastats/DataSet_Plot.py b/alphastats/DataSet_Plot.py index d2214d46..5b110f1d 100644 --- a/alphastats/DataSet_Plot.py +++ b/alphastats/DataSet_Plot.py @@ -4,10 +4,10 @@ import plotly import plotly.express as px import plotly.figure_factory -import plotly.graph_objects as go import scipy import seaborn as sns +from alphastats.keys import Cols from alphastats.plots.PlotUtils import PlotUtils from alphastats.utils import check_for_missing_values @@ -51,13 +51,11 @@ def __init__( mat: pd.DataFrame, rawmat: pd.DataFrame, metadata: pd.DataFrame, - sample: str, preprocessing_info: Dict, ): self.mat: pd.DataFrame = mat self.rawmat: pd.DataFrame = rawmat self.metadata: pd.DataFrame = metadata - self.sample: str = sample self.preprocessing_info: Dict = preprocessing_info def plot_correlation_matrix(self, method: str = "pearson"): # TODO unused @@ -96,15 +94,18 @@ def plot_sampledistribution( # create long df matrix = self.mat if not use_raw else self.rawmat df = matrix.unstack().reset_index() - df.rename(columns={"level_1": self.sample, 0: "Intensity"}, inplace=True) + # TODO replace intensity either with the more generic term abundance, + # or use what was actually the original name. 
+ # Intensity or LFQ intensity, or even SILAC ratio makes a bit difference + df.rename(columns={"level_1": Cols.SAMPLE, 0: "Intensity"}, inplace=True) if color is not None: - df = df.merge(self.metadata, how="inner", on=[self.sample]) + df = df.merge(self.metadata, how="inner", on=[Cols.SAMPLE]) if method == "violin": fig = px.violin( df, - x=self.sample, + x=Cols.SAMPLE, y="Intensity", color=color, template="simple_white+alphastats_colors", @@ -113,7 +114,7 @@ def plot_sampledistribution( elif method == "box": fig = px.box( df, - x=self.sample, + x=Cols.SAMPLE, y="Intensity", color=color, template="simple_white+alphastats_colors", diff --git a/alphastats/DataSet_Preprocess.py b/alphastats/DataSet_Preprocess.py index 16606c44..acc014f8 100644 --- a/alphastats/DataSet_Preprocess.py +++ b/alphastats/DataSet_Preprocess.py @@ -1,5 +1,5 @@ import logging -from typing import Dict, List, Tuple, Union +from typing import Dict, List, Tuple import numpy as np import pandas as pd @@ -9,10 +9,11 @@ import streamlit as st from sklearn.experimental import enable_iterative_imputer # noqa +from alphastats.keys import Cols, ConstantsClass from alphastats.utils import ignore_warning -class PreprocessingStateKeys: +class PreprocessingStateKeys(metaclass=ConstantsClass): """Keys for accessing the dictionary holding the information about preprocessing.""" # TODO disentangle these keys from the human-readably display strings @@ -45,8 +46,6 @@ def __init__( self, filter_columns: List[str], rawinput: pd.DataFrame, - index_column: str, - sample: str, metadata: pd.DataFrame, preprocessing_info: Dict, mat: pd.DataFrame, @@ -54,8 +53,6 @@ def __init__( self.filter_columns = filter_columns self.rawinput = rawinput - self.index_column = index_column - self.sample = sample self.metadata = metadata self.preprocessing_info = preprocessing_info @@ -89,17 +86,17 @@ def init_preprocessing_info( def _remove_samples(self, sample_list: list): # exclude samples for analysis self.mat = self.mat.drop(sample_list) - self.metadata = self.metadata[~self.metadata[self.sample].isin(sample_list)] + self.metadata = self.metadata[~self.metadata[Cols.SAMPLE].isin(sample_list)] @staticmethod def subset( - mat: pd.DataFrame, metadata: pd.DataFrame, sample: str, preprocessing_info: Dict + mat: pd.DataFrame, metadata: pd.DataFrame, preprocessing_info: Dict ) -> pd.DataFrame: """Filter matrix so only samples that are described in metadata are also found in matrix.""" preprocessing_info.update( {PreprocessingStateKeys.NUM_SAMPLES: metadata.shape[0]} ) - return mat[mat.index.isin(metadata[sample].tolist())] + return mat[mat.index.isin(metadata[Cols.SAMPLE].tolist())] def _remove_na_values(self, cut_off): if ( @@ -154,10 +151,10 @@ def _filter(self): logging.info("Contaminatons have already been filtered.") return - #  print column names with contamination + # print column names with contamination protein_groups_to_remove = self.rawinput[ self.rawinput[self.filter_columns].any(axis=1) - ][self.index_column].tolist() + ][Cols.INDEX].tolist() protein_groups_to_remove = list( set(protein_groups_to_remove) & set(self.mat.columns.to_list()) @@ -351,7 +348,7 @@ def batch_correction(self, batch: str) -> pd.DataFrame: from combat.pycombat import pycombat data = self.mat.transpose() - series_of_batches = self.metadata.set_index(self.sample).reindex( + series_of_batches = self.metadata.set_index(Cols.SAMPLE).reindex( data.columns.to_list() )[batch] @@ -419,6 +416,8 @@ def preprocess( ]: raise ValueError(f"Invalid keyword argument: {k}") + # TODO this is a 
stateful method as we change self.mat, self.metadata and self.processing_info + # refactor such that it does not change self.mat etc but just return the latest result if remove_contaminations: self._filter() @@ -426,9 +425,7 @@ def preprocess( self._remove_samples(sample_list=remove_samples) if subset: - self.mat = self.subset( - self.mat, self.metadata, self.sample, self.preprocessing_info - ) + self.mat = self.subset(self.mat, self.metadata, self.preprocessing_info) if data_completeness > 0: self._remove_na_values(cut_off=data_completeness) diff --git a/alphastats/DataSet_Statistics.py b/alphastats/DataSet_Statistics.py index bf3a8be7..a0e3f66d 100644 --- a/alphastats/DataSet_Statistics.py +++ b/alphastats/DataSet_Statistics.py @@ -1,11 +1,10 @@ from functools import lru_cache -from typing import Dict, Union +from typing import Dict, Tuple, Union -import numpy as np import pandas as pd import pingouin -from alphastats.DataSet_Preprocess import PreprocessingStateKeys +from alphastats.keys import Cols from alphastats.statistics.Anova import Anova from alphastats.statistics.DifferentialExpressionAnalysis import ( DifferentialExpressionAnalysis, @@ -20,14 +19,10 @@ def __init__( *, mat: pd.DataFrame, metadata: pd.DataFrame, - index_column: str, - sample: str, preprocessing_info: Dict, ): self.mat: pd.DataFrame = mat self.metadata: pd.DataFrame = metadata - self.index_column: str = index_column - self.sample: str = sample self.preprocessing_info: Dict = preprocessing_info @ignore_warning(RuntimeWarning) @@ -65,8 +60,6 @@ def diff_expression_analysis( df = DifferentialExpressionAnalysis( mat=self.mat, metadata=self.metadata, - index_column=self.index_column, - sample=self.sample, preprocessing_info=self.preprocessing_info, group1=group1, group2=group2, @@ -95,8 +88,6 @@ def anova(self, column: str, protein_ids="all", tukey: bool = True) -> pd.DataFr return Anova( mat=self.mat, metadata=self.metadata, - sample=self.sample, - index_column=self.index_column, column=column, protein_ids=protein_ids, tukey=tukey, @@ -126,41 +117,42 @@ def ancova( * ``'p-unc'``: Uncorrected p-values * ``'np2'``: Partial eta-squared """ - df = self.mat[protein_id].reset_index().rename(columns={"index": self.sample}) - df = self.metadata.merge(df, how="inner", on=[self.sample]) + df = self.mat[protein_id].reset_index().rename(columns={"index": Cols.SAMPLE}) + df = self.metadata.merge(df, how="inner", on=[Cols.SAMPLE]) ancova_df = pingouin.ancova(df, dv=protein_id, covar=covar, between=between) return ancova_df - # @ignore_warning(RuntimeWarning) - # def multicova_analysis( # TODO never used outside of tests .. how does this relate to multicova.py? - # self, - # covariates: list, - # n_permutations: int = 3, - # fdr: float = 0.05, - # s0: float = 0.05, - # subset: dict = None, - # ) -> Union[pd.DataFrame, list]: - # """Perform Multicovariat Analysis - # will return a pandas DataFrame with the results and a list of volcano plots (for each covariat) - # - # Args: - # covariates (list): list of covariates, column names in metadata - # n_permutations (int, optional): number of permutations. Defaults to 3. - # fdr (float, optional): False Discovery Rate. Defaults to 0.05. - # s0 (float, optional): . Defaults to 0.05. - # subset (dict, optional): for categorical covariates . Defaults to None. - # - # Returns: - # pd.DataFrame: Multicova Analysis results - # """ - # - # res, plot_list = MultiCovaAnalysis( - # dataset=self, # TODO fix .. does this write to it? 
- # covariates=covariates, - # n_permutations=n_permutations, - # fdr=fdr, - # s0=s0, - # subset=subset, - # plot=True, - # ).calculate() - # return res, plot_list + @ignore_warning(RuntimeWarning) + def multicova_analysis( # TODO never used outside of tests .. how does this relate to multicova.py? + self, + covariates: list, + n_permutations: int = 3, + fdr: float = 0.05, + s0: float = 0.05, + subset: dict = None, + ) -> Tuple[pd.DataFrame, list]: + """Perform Multicovariat Analysis + will return a pandas DataFrame with the results and a list of volcano plots (for each covariat) + + Args: + covariates (list): list of covariates, column names in metadata + n_permutations (int, optional): number of permutations. Defaults to 3. + fdr (float, optional): False Discovery Rate. Defaults to 0.05. + s0 (float, optional): . Defaults to 0.05. + subset (dict, optional): for categorical covariates . Defaults to None. + + Returns: + pd.DataFrame: Multicova Analysis results + """ + + res, plot_list = MultiCovaAnalysis( + mat=self.mat, + metadata=self.metadata, + covariates=covariates, + n_permutations=n_permutations, + fdr=fdr, + s0=s0, + subset=subset, + ).calculate() + + return res, plot_list diff --git a/alphastats/__init__.py b/alphastats/__init__.py index d25740ed..f8025997 100644 --- a/alphastats/__init__.py +++ b/alphastats/__init__.py @@ -40,11 +40,11 @@ } # TODO get rid of these imports -import alphastats.gui +import alphastats.gui # noqa: F401 -from .cli import * -from .loader.AlphaPeptLoader import * -from .loader.DIANNLoader import * -from .loader.FragPipeLoader import * -from .loader.MaxQuantLoader import * -from .loader.SpectronautLoader import * +from .cli import * # noqa: F403 +from .loader.AlphaPeptLoader import * # noqa: F403 +from .loader.DIANNLoader import * # noqa: F403 +from .loader.FragPipeLoader import * # noqa: F403 +from .loader.MaxQuantLoader import * # noqa: F403 +from .loader.SpectronautLoader import * # noqa: F403 diff --git a/alphastats/cli.py b/alphastats/cli.py index 097cf7c0..ec0998e8 100644 --- a/alphastats/cli.py +++ b/alphastats/cli.py @@ -2,6 +2,4 @@ def run(): - import alphastats - alphastats.gui.gui.run() diff --git a/alphastats/dataset_factory.py b/alphastats/dataset_factory.py index 9e312124..de049cf4 100644 --- a/alphastats/dataset_factory.py +++ b/alphastats/dataset_factory.py @@ -1,10 +1,13 @@ import logging import warnings -from typing import Dict, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import numpy as np import pandas as pd +from alphastats.dataset_harmonizer import DataHarmonizer +from alphastats.keys import Cols + class DataSetFactory: """Create all 'heavy' data structures of a DataSet.""" @@ -13,22 +16,20 @@ def __init__( self, *, rawinput: pd.DataFrame, - index_column: str, intensity_column: Union[List[str], str], metadata_path_or_df: Union[str, pd.DataFrame], - sample_column: str, + data_harmonizer: DataHarmonizer, ): self.rawinput: pd.DataFrame = rawinput - self.sample_column: str = sample_column - self.index_column: str = index_column self.intensity_column: Union[List[str], str] = intensity_column self.metadata_path_or_df: Union[str, pd.DataFrame] = metadata_path_or_df + self._data_harmonizer = data_harmonizer def create_matrix_from_rawinput(self) -> Tuple[pd.DataFrame, pd.DataFrame]: """Creates a matrix: features (Proteins) as columns, samples as rows.""" df = self.rawinput - df = df.set_index(self.index_column) + df = df.set_index(Cols.INDEX) if isinstance(self.intensity_column, str): 
regex_find_intensity_columns = self.intensity_column.replace( @@ -58,28 +59,27 @@ def _check_matrix_values(mat: pd.DataFrame) -> None: if np.isinf(mat).values.sum() > 0: logging.warning("Data contains infinite values.") - def create_metadata(self, mat: pd.DataFrame) -> Tuple[pd.DataFrame, str]: + def create_metadata(self, mat: pd.DataFrame) -> pd.DataFrame: """Create metadata DataFrame from metadata file or DataFrame.""" if self.metadata_path_or_df is not None: - sample = self.sample_column metadata = self._load_metadata(file_path=self.metadata_path_or_df) - metadata = self._remove_missing_samples_from_metadata(mat, metadata, sample) + metadata = self._data_harmonizer.get_harmonized_metadata(metadata) + metadata = self._remove_missing_samples_from_metadata(mat, metadata) else: - sample = "sample" - metadata = pd.DataFrame({"sample": list(mat.index)}) + metadata = pd.DataFrame({Cols.SAMPLE: list(mat.index)}) - return metadata, sample + return metadata def _remove_missing_samples_from_metadata( - self, mat: pd.DataFrame, metadata: pd.DataFrame, sample + self, mat: pd.DataFrame, metadata: pd.DataFrame ) -> pd.DataFrame: """Remove samples from metadata that are not in the protein data.""" samples_matrix = mat.index.to_list() - samples_metadata = metadata[sample].to_list() + samples_metadata = metadata[Cols.SAMPLE].to_list() misc_samples = list(set(samples_metadata) - set(samples_matrix)) if len(misc_samples) > 0: - metadata = metadata[~metadata[sample].isin(misc_samples)] + metadata = metadata[~metadata[Cols.SAMPLE].isin(misc_samples)] logging.warning( f"{misc_samples} are not described in the protein data and" "are removed from the metadata." @@ -116,12 +116,13 @@ def _load_metadata( ) return None - if df is not None and self.sample_column not in df.columns: - logging.error( - f"sample_column: {self.sample_column} not found in {file_path}" - ) - # check whether sample labeling matches protein data # warnings.warn("WARNING: Sample names do not match sample labelling in protein data") df.columns = df.columns.astype(str) + + # TODO document this + df.drop( + columns=[c for c in df.columns if c.startswith("_IGNORE_")], inplace=True + ) + return df diff --git a/alphastats/dataset_harmonizer.py b/alphastats/dataset_harmonizer.py new file mode 100644 index 00000000..a1f0a9ae --- /dev/null +++ b/alphastats/dataset_harmonizer.py @@ -0,0 +1,59 @@ +"""Harmonize the input data to a common format.""" + +from typing import Dict, Optional + +import pandas as pd + +from alphastats import BaseLoader +from alphastats.keys import Cols + + +class DataHarmonizer: + """Harmonize input data to a common format.""" + + def __init__(self, loader: BaseLoader, sample_column_name: Optional[str] = None): + _rawinput_rename_dict = { + loader.index_column: Cols.INDEX, + } + if loader.gene_names_column is not None: + _rawinput_rename_dict[loader.gene_names_column] = Cols.GENE_NAMES + + self._rawinput_rename_dict = _rawinput_rename_dict + + self._metadata_rename_dict = ( + { + sample_column_name: Cols.SAMPLE, + } + if sample_column_name is not None + else {} + ) + + def get_harmonized_rawinput(self, rawinput: pd.DataFrame) -> pd.DataFrame: + """Harmonize the rawinput data to a common format.""" + return self._get_harmonized_data( + rawinput, + self._rawinput_rename_dict, + ) + + def get_harmonized_metadata(self, metadata: pd.DataFrame) -> pd.DataFrame: + """Harmonize the rawinput data to a common format.""" + return self._get_harmonized_data( + metadata, + self._metadata_rename_dict, + ) + + @staticmethod + def 
_get_harmonized_data( + input_df: pd.DataFrame, rename_dict: Dict[str, str] + ) -> pd.DataFrame: + """Harmonize data to a common format.""" + for target_name in rename_dict.values(): + if target_name in input_df.columns: + raise ValueError( + f"Column name '{target_name}' already exists. Please rename the column in your input data." + ) + + return input_df.rename( + columns=rename_dict, + errors="raise", + ) diff --git a/alphastats/gui/AlphaPeptStats.py b/alphastats/gui/AlphaPeptStats.py index 182bbff3..6454d0d7 100644 --- a/alphastats/gui/AlphaPeptStats.py +++ b/alphastats/gui/AlphaPeptStats.py @@ -29,9 +29,7 @@ _this_directory = os.path.dirname(_this_file) icon = os.path.join(_this_directory, "alphapeptstats_logo.png") -header_html = img_center + "".format( - img_to_bytes(icon) -) +header_html = img_center + f"" st.markdown( header_html, @@ -60,4 +58,4 @@ # https://discuss.streamlit.io/t/icons-for-the-multi-app-page-menu-in-the-sidebar-other-than-emojis/27222 # https://icons.getbootstrap.com/ # https://medium.com/codex/create-a-multi-page-app-with-the-new-streamlit-option-menu-component-3e3edaf7e7ad -#  https://lightrun.com/answers/streamlit-streamlit-set-multipage-app-emoji-in-stpage_config-not-filename +# https://lightrun.com/answers/streamlit-streamlit-set-multipage-app-emoji-in-stpage_config-not-filename diff --git a/alphastats/gui/__init__.py b/alphastats/gui/__init__.py index dcaea0ee..a4589819 100644 --- a/alphastats/gui/__init__.py +++ b/alphastats/gui/__init__.py @@ -1 +1 @@ -from .gui import run +from .gui import run # noqa: F401 TODO: check if necessary diff --git a/alphastats/gui/example_data/metadata.xlsx b/alphastats/gui/example_data/metadata.xlsx new file mode 100644 index 00000000..f893417b Binary files /dev/null and b/alphastats/gui/example_data/metadata.xlsx differ diff --git a/alphastats/gui/sample_data/proteinGroups.txt b/alphastats/gui/example_data/proteinGroups.txt similarity index 100% rename from alphastats/gui/sample_data/proteinGroups.txt rename to alphastats/gui/example_data/proteinGroups.txt diff --git a/alphastats/gui/gui.py b/alphastats/gui/gui.py index a4e0e8be..a5a7f5e8 100644 --- a/alphastats/gui/gui.py +++ b/alphastats/gui/gui.py @@ -1,7 +1,4 @@ import os -import sys - -from streamlit.web import cli as stcli def run(): @@ -18,22 +15,24 @@ def run(): os.system( f"python -m streamlit run AlphaPeptStats.py --global.developmentMode=false {extra_args}" ) - _this_file = os.path.abspath(__file__) - _this_directory = os.path.dirname(_this_file) - - file_path = os.path.join(_this_directory, "AlphaPeptStats.py") - - HOME = os.path.expanduser("~") - ST_PATH = os.path.join(HOME, ".streamlit") - - for folder in [ST_PATH]: - if not os.path.isdir(folder): - os.mkdir(folder) - - print(f"Starting AlphaPeptStats from {file_path}") - - args = ["streamlit", "run", file_path, "--global.developmentMode=false"] - - sys.argv = args - sys.exit(stcli.main()) + # TODO why are we starting the app a second time here? 
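(Aside: for reference, the legacy launch path preserved in the commented-out block below reassembles into a single programmatic entry point. This is a sketch based only on the deleted lines; the file name and the `--global.developmentMode=false` flag are carried over from them.)

```python
import os
import sys

from streamlit.web import cli as stcli


def run_via_cli() -> None:
    """Launch the GUI once via Streamlit's CLI (reassembled from the removed code)."""
    this_directory = os.path.dirname(os.path.abspath(__file__))
    file_path = os.path.join(this_directory, "AlphaPeptStats.py")

    print(f"Starting AlphaPeptStats from {file_path}")

    # Hand control to Streamlit's CLI exactly once, instead of shelling out first.
    sys.argv = ["streamlit", "run", file_path, "--global.developmentMode=false"]
    sys.exit(stcli.main())
```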
+ # _this_file = os.path.abspath(__file__) + # _this_directory = os.path.dirname(_this_file) + # + # file_path = os.path.join(_this_directory, "AlphaPeptStats.py") + # + # HOME = os.path.expanduser("~") + # ST_PATH = os.path.join(HOME, ".streamlit") + # + # for folder in [ST_PATH]: + # if not os.path.isdir(folder): + # os.mkdir(folder) + # + # print(f"Starting AlphaPeptStats from {file_path}") + # + # args = ["streamlit", "run", file_path, "--global.developmentMode=false"] + # + # sys.argv = args + # + # sys.exit(stcli.main()) diff --git a/alphastats/gui/pages/02_Import Data.py b/alphastats/gui/pages/02_Import Data.py index be191f64..97402f22 100644 --- a/alphastats/gui/pages/02_Import Data.py +++ b/alphastats/gui/pages/02_Import Data.py @@ -1,12 +1,8 @@ -from typing import List - import streamlit as st -from alphastats import BaseLoader from alphastats.DataSet import DataSet from alphastats.gui.utils.import_helper import ( load_example_data, - load_options, load_proteomics_data, show_button_download_metadata_template_file, show_loader_columns_selection, @@ -23,26 +19,23 @@ def _finalize_data_loading( - loader: BaseLoader, - metadata_columns: List[str], dataset: DataSet, ) -> None: """Finalize the data loading process.""" - st.session_state[StateKeys.LOADER] = ( - loader # TODO: Figure out if we even need the loader here, as the dataset has the loader as an attribute. - ) - st.session_state[StateKeys.METADATA_COLUMNS] = metadata_columns st.session_state[StateKeys.DATASET] = dataset - load_options() sidebar_info() st.page_link("pages/03_Data Overview.py", label="=> Go to data overview page..") +st.set_page_config(layout="wide") init_session_state() sidebar_info() +st.markdown("## Import Data") + + st.markdown("### Start a new session") st.write( "Start a new session will discard the current one (including all analysis!) and enable importing a new dataset." @@ -58,15 +51,15 @@ def _finalize_data_loading( if c2.button("Start new Session with example DataSet", key="_load_example_data"): empty_session_state() init_session_state() - loader, metadata_columns, dataset = load_example_data() + dataset = load_example_data() - _finalize_data_loading(loader, metadata_columns, dataset) + _finalize_data_loading(dataset) st.stop() st.markdown("### Import Proteomics Data") if StateKeys.DATASET in st.session_state: - st.info(f"DataSet already present.") + st.info("DataSet already present.") st.page_link("pages/03_Data Overview.py", label="=> Go to data overview page..") st.stop() @@ -126,26 +119,23 @@ def _finalize_data_loading( "Upload metadata file with information about your samples", ) -if metadatafile_upload is None: - st.stop() - -metadatafile_df = uploaded_file_to_df(metadatafile_upload) +metadatafile_df = None +if metadatafile_upload is not None: + metadatafile_df = uploaded_file_to_df(metadatafile_upload) -sample_column = show_select_sample_column_for_metadata( - metadatafile_df, software, loader -) + sample_column = show_select_sample_column_for_metadata( + metadatafile_df, software, loader + ) # ########## Create dataset st.markdown("##### 4. 
Create DataSet") dataset = None -metadata_columns = [] c1, c2 = st.columns(2) if c2.button("Create DataSet without metadata"): dataset = DataSet(loader=loader) - metadata_columns = ["sample"] if c1.button( "Create DataSet with metadata", @@ -162,8 +152,7 @@ def _finalize_data_loading( metadata_path_or_df=metadatafile_df, sample_column=sample_column, ) - metadata_columns = metadatafile_df.columns.to_list() if dataset is not None: - st.info("DataSet has been created.") - _finalize_data_loading(loader, metadata_columns, dataset) + st.toast(" DataSet has been created.", icon="✅") + _finalize_data_loading(dataset) diff --git a/alphastats/gui/pages/03_Data Overview.py b/alphastats/gui/pages/03_Data Overview.py index dfa1bfc1..d58792ab 100644 --- a/alphastats/gui/pages/03_Data Overview.py +++ b/alphastats/gui/pages/03_Data Overview.py @@ -9,18 +9,19 @@ ) from alphastats.gui.utils.ui_helper import StateKeys, init_session_state, sidebar_info +st.set_page_config(layout="wide") init_session_state() sidebar_info() +st.markdown("## Data Overview") + if StateKeys.DATASET not in st.session_state: - st.info("Import Data first") + st.info("Import data first.") st.stop() -st.markdown("### DataSet Info") - display_loaded_dataset(st.session_state[StateKeys.DATASET]) -st.markdown("## DataSet overview") +st.markdown("### Intensities") c1, c2 = st.columns(2) diff --git a/alphastats/gui/pages/03_Preprocessing.py b/alphastats/gui/pages/03_Preprocessing.py index d1a668a9..9b98f61c 100644 --- a/alphastats/gui/pages/03_Preprocessing.py +++ b/alphastats/gui/pages/03_Preprocessing.py @@ -1,8 +1,7 @@ -import pandas as pd import streamlit as st +from alphastats.DataSet_Preprocess import PreprocessingStateKeys from alphastats.gui.utils.preprocessing_helper import ( - PREPROCESSING_STEPS, configure_preprocessing, display_preprocessing_info, draw_workflow, @@ -12,47 +11,56 @@ ) from alphastats.gui.utils.ui_helper import StateKeys, init_session_state, sidebar_info +st.set_page_config(layout="wide") init_session_state() sidebar_info() -if StateKeys.WORKFLOW not in st.session_state: - st.session_state[StateKeys.WORKFLOW] = [ - PREPROCESSING_STEPS.REMOVE_CONTAMINATIONS, - PREPROCESSING_STEPS.SUBSET, - PREPROCESSING_STEPS.LOG2_TRANSFORM, - ] -st.markdown("### Preprocessing") -c1, c2 = st.columns([1, 1]) +st.markdown("## Preprocessing") -with c2: - if StateKeys.DATASET in st.session_state: - settings = configure_preprocessing(dataset=st.session_state[StateKeys.DATASET]) - new_workflow = update_workflow(settings) - if new_workflow != st.session_state[StateKeys.WORKFLOW]: - st.session_state[StateKeys.WORKFLOW] = new_workflow +if StateKeys.DATASET not in st.session_state: + st.info("Import data first.") + st.stop() + +c1, _, c2 = st.columns([0.3, 0.1, 0.45]) + + +dataset = st.session_state[StateKeys.DATASET] with c1: - st.write("#### Flowchart of preprocessing workflow:") + st.write("##### Select preprocessing steps") + settings = configure_preprocessing(dataset=dataset) + new_workflow = update_workflow(settings) + if new_workflow != st.session_state[StateKeys.WORKFLOW]: + st.session_state[StateKeys.WORKFLOW] = new_workflow - selected_nodes = draw_workflow(st.session_state[StateKeys.WORKFLOW]) + is_preprocessing_done = dataset.preprocessing_info[ + PreprocessingStateKeys.PREPROCESSING_DONE + ] + + if is_preprocessing_done: + st.success("Preprocessing finished successfully!", icon="✅") + + c11, c12 = st.columns([1, 1]) + if c11.button( + "Run preprocessing", key="_run_preprocessing", disabled=is_preprocessing_done + ): + 
run_preprocessing(settings, dataset) + st.rerun() - if StateKeys.DATASET not in st.session_state: - st.info("Import data first to configure and run preprocessing") + if c12.button( + "❌ Reset preprocessing", + key="_reset_preprocessing", + disabled=not is_preprocessing_done, + ): + reset_preprocessing(dataset) + st.rerun() - else: - c11, c12 = st.columns([1, 1]) - if c11.button("Run preprocessing", key="_run_preprocessing"): - run_preprocessing(settings, st.session_state[StateKeys.DATASET]) - # TODO show more info about the preprocessing steps - display_preprocessing_info( - st.session_state[StateKeys.DATASET].preprocessing_info - ) +with c2: + selected_nodes = draw_workflow(st.session_state[StateKeys.WORKFLOW]) - if c12.button("Reset all Preprocessing steps", key="_reset_preprocessing"): - reset_preprocessing(st.session_state[StateKeys.DATASET]) - display_preprocessing_info( - st.session_state[StateKeys.DATASET].preprocessing_info - ) + st.markdown("##### Current preprocessing status") + display_preprocessing_info(dataset.preprocessing_info) +# TODO add help to individual steps with more info # TODO: Add comparison plot of intensity distribution before and after preprocessing diff --git a/alphastats/gui/pages/04_Analysis.py b/alphastats/gui/pages/04_Analysis.py index 97f33895..7556e71c 100644 --- a/alphastats/gui/pages/04_Analysis.py +++ b/alphastats/gui/pages/04_Analysis.py @@ -1,142 +1,101 @@ import streamlit as st +from alphastats.gui.utils.analysis import PlottingOptions, StatisticOptions from alphastats.gui.utils.analysis_helper import ( - display_df, - display_figure, - download_figure, - download_preprocessing_info, - get_analysis, - load_options, - save_plot_to_session_state, + display_analysis_result_with_buttons, + gather_parameters_and_do_analysis, ) from alphastats.gui.utils.ui_helper import ( StateKeys, - convert_df, init_session_state, sidebar_info, ) - -def select_analysis(): - """ - select box - loads keys from option dicts - """ - load_options() - method = st.selectbox( - "Analysis", - options=list(st.session_state[StateKeys.PLOTTING_OPTIONS].keys()) - + list(st.session_state[StateKeys.STATISTIC_OPTIONS].keys()), - ) - return method - - +st.set_page_config(layout="wide") init_session_state() sidebar_info() -st.markdown("### Analysis") +st.markdown("## Analysis") # set background to white so downloaded pngs dont have grey background -styl = f""" +styl = """ """ st.markdown(styl, unsafe_allow_html=True) +# TODO use caching functionality for all analysis (not: plot creation) -if StateKeys.PLOT_LIST not in st.session_state: - st.session_state[StateKeys.PLOT_LIST] = [] - - -if StateKeys.DATASET in st.session_state: - c1, c2 = st.columns((1, 2)) - - plot_to_display = False - df_to_display = False - method_plot = None - - with c1: - method = select_analysis() - - if method in st.session_state[StateKeys.PLOTTING_OPTIONS]: - analysis_result = get_analysis( - method=method, options_dict=st.session_state[StateKeys.PLOTTING_OPTIONS] - ) - plot_to_display = True - - elif method in st.session_state[StateKeys.STATISTIC_OPTIONS]: - analysis_result = get_analysis( - method=method, - options_dict=st.session_state[StateKeys.STATISTIC_OPTIONS], - ) - df_to_display = True - - st.markdown("") - st.markdown("") - st.markdown("") - st.markdown("") - - with c2: - # --- Plot ------------------------------------------------------- +if StateKeys.DATASET not in st.session_state: + st.info("Import data first.") + st.stop() - if analysis_result is not None and method != "Clustermap" and 
plot_to_display: - display_figure(analysis_result) +# --- SELECTION ------------------------------------------------------- +show_plot = False +show_df = False +analysis_result = None - save_plot_to_session_state(analysis_result, method) - method_plot = [method, analysis_result] - - elif method == "Clustermap": - st.write("Download Figure to see full size.") - - display_figure(analysis_result) - - save_plot_to_session_state(analysis_result, method) - - # --- STATISTICAL ANALYSIS ------------------------------------------------------- - - elif analysis_result is not None and df_to_display: - display_df(analysis_result) - - filename = method + ".csv" - csv = convert_df(analysis_result) - - if analysis_result is not None and method != "Clustermap" and plot_to_display: - col1, col2, col3 = st.columns([1, 1, 1]) +c1, c2 = st.columns([0.33, 0.67]) +with c1: + plotting_options = PlottingOptions.get_values() + statistic_options = StatisticOptions.get_values() + analysis_method = st.selectbox( + "Analysis", + options=["" + custom_group_option = "Custom groups from samples .." + + grouping_variable = st.selectbox( + "Grouping variable", + options=[default_option] + + metadata.columns.to_list() + + [custom_group_option], + ) + + column = None + if grouping_variable == default_option: + group1 = st.selectbox("Group 1", options=[]) + group2 = st.selectbox("Group 2", options=[]) + + elif grouping_variable != custom_group_option: + unique_values = metadata[grouping_variable].unique().tolist() + + column = grouping_variable + group1 = st.selectbox("Group 1", options=unique_values) + group2 = st.selectbox("Group 2", options=list(reversed(unique_values))) + + else: + group1 = st.multiselect( + "Group 1 samples:", + options=metadata[Cols.SAMPLE].to_list(), + ) + + group2 = st.multiselect( + "Group 2 samples:", + options=list(reversed(metadata[Cols.SAMPLE].to_list())), + ) + + intersection_list = list(set(group1).intersection(set(group2))) + if len(intersection_list) > 0: + st.warning( + "Group 1 and Group 2 contain same samples: " + + str(intersection_list) + ) + + self._parameters.update({"group1": group1, "group2": group2}) + if column is not None: + self._parameters["column"] = column + + def _pre_analysis_check(self): + """Raise if selected groups are different.""" + if self._parameters["group1"] == self._parameters["group2"]: + raise ( + ValueError( + "Group 1 and Group 2 can not be the same. Please select different groups." 
+                )
+            )
+
+
+class AbstractDimensionReductionAnalysis(AbstractAnalysis, ABC):
+    """Abstract class for dimension reduction analysis widgets."""
+
+    def show_widget(self):
+        """Gather parameters for dimension reduction analysis."""
+
+        group = st.selectbox(
+            "Color according to",
+            options=[None] + self._dataset.metadata.columns.to_list(),
+        )
+
+        circle = st.checkbox("circle")
+
+        self._parameters.update({"circle": circle, "group": group})
+
+
+class AbstractIntensityPlot(AbstractAnalysis, ABC):
+    """Abstract class for intensity plot analysis widgets."""
+
+    def show_widget(self):
+        """Gather parameters for intensity plot analysis."""
+
+        group = st.selectbox(
+            "Color according to",
+            options=[None] + self._dataset.metadata.columns.to_list(),
+        )
+        method = st.selectbox(
+            "Plot layout",
+            options=["violin", "box", "scatter"],
+        )
+
+        self._parameters.update({"group": group, "method": method})
+
+
+class IntensityPlot(AbstractIntensityPlot, ABC):
+    """Widget for intensity plot analysis."""
+
+    def show_widget(self):
+        """Gather parameters for intensity plot analysis."""
+        super().show_widget()
+
+        protein_id = st.selectbox(
+            "ProteinID/ProteinGroup",
+            options=self._dataset.mat.columns.to_list(),
+        )
+
+        self._parameters.update({"protein_id": protein_id})
+
+    def _do_analysis(self):
+        """Draw Intensity Plot using the IntensityPlot class."""
+        intensity_plot = self._dataset.plot_intensity(
+            protein_id=self._parameters["protein_id"],
+            method=self._parameters["method"],
+            group=self._parameters["group"],
+        )
+        return intensity_plot, None
+
+
+class SampleDistributionPlot(AbstractIntensityPlot, ABC):
+    """Widget for sample distribution plot analysis."""
+
+    def _do_analysis(self):
+        """Draw the sample distribution plot."""
+        intensity_plot = self._dataset.plot_sampledistribution(
+            method=self._parameters["method"],
+            color=self._parameters["group"],  # the parameter is named `color` here; not a typo
+        )
+        return intensity_plot, None
+
+
+class PCAPlotAnalysis(AbstractDimensionReductionAnalysis):
+    """Widget for PCA Plot analysis."""
+
+    def _do_analysis(self):
+        """Draw PCA Plot using the DimensionalityReduction class."""
+
+        pca_plot = self._dataset.plot_pca(
+            group=self._parameters["group"],
+            circle=self._parameters["circle"],
+        )
+        return pca_plot, None
+
+
+class UMAPPlotAnalysis(AbstractDimensionReductionAnalysis):
+    """Widget for UMAP Plot analysis."""
+
+    def _do_analysis(self):
+        """Draw UMAP Plot using the DimensionalityReduction class."""
+        umap_plot = self._dataset.plot_umap(
+            group=self._parameters["group"],
+            circle=self._parameters["circle"],
+        )
+        return umap_plot, None
+
+
+class TSNEPlotAnalysis(AbstractDimensionReductionAnalysis):
+    """Widget for t-SNE Plot analysis."""
+
+    def show_widget(self):
+        """Show the widget and gather parameters."""
+        super().show_widget()
+
+        n_iter = st.select_slider(
+            "Maximum number of iterations for the optimization",
+            range(250, 2001),
+            value=1000,
+        )
+        perplexity = st.select_slider("Perplexity", range(5, 51), value=30)
+
+        self._parameters.update(
+            {
+                "n_iter": n_iter,
+                "perplexity": perplexity,
+            }
+        )
+
+    def _do_analysis(self):
+        """Draw t-SNE Plot using the DimensionalityReduction class."""
+        tsne_plot = self._dataset.plot_tsne(
+            group=self._parameters["group"],
+            circle=self._parameters["circle"],
+            perplexity=self._parameters["perplexity"],
+            n_iter=self._parameters["n_iter"],
+        )
+        return tsne_plot, None
+
+
+class VolcanoPlotAnalysis(AbstractGroupCompareAnalysis):
+    """Widget for Volcano Plot analysis."""
+
+    def show_widget(self):
+        """Show the widget and gather parameters."""
+        super().show_widget()
+
+        parameters = {}
+        method = st.selectbox(
+            "Differential Analysis using:",
+            options=["ttest", "anova", "wald", "sam", "paired-ttest", "welch-ttest"],
+        )
+        parameters["method"] = method
+
+        parameters["labels"] = st.checkbox("Add labels", value=True)
+
+        parameters["draw_line"] = st.checkbox("Draw lines", value=True)
+
+        parameters["alpha"] = st.number_input(
+            label="alpha", min_value=0.001, max_value=0.050, value=0.050
+        )
+
+        parameters["min_fc"] = st.select_slider(
+            "Foldchange cutoff", range(0, 3), value=1
+        )
+
+        # TODO: The sam fdr cutoff should be mutually exclusive with alpha
+        if method == "sam":
+            parameters["perm"] = st.number_input(
+                label="Number of Permutations", min_value=1, max_value=1000, value=10
+            )
+            parameters["fdr"] = st.number_input(
+                label="FDR cut off", min_value=0.005, max_value=0.1, value=0.050
+            )
+
+        self._parameters.update(parameters)
+
+    def _do_analysis(self):
+        """Draw Volcano Plot using the VolcanoPlot class.
+
+        Returns a tuple (figure, analysis_object), where figure is the plot and
+        analysis_object is the underlying VolcanoPlot instance.
+        """
+        # Note that currently, values that are not set by the UI would still be passed as None to the VolcanoPlot class,
+        # thus overwriting the default values set therein.
+        # If we introduce optional parameters in the UI, either use `inspect` to get the defaults from the class,
+        # or refactor it so that all default values are `None` and the class sets the defaults programmatically.
+        volcano_plot = VolcanoPlot(
+            mat=self._dataset.mat,
+            rawinput=self._dataset.rawinput,
+            metadata=self._dataset.metadata,
+            preprocessing_info=self._dataset.preprocessing_info,
+            group1=self._parameters["group1"],
+            group2=self._parameters["group2"],
+            column=self._parameters["column"],
+            method=self._parameters["method"],
+            labels=self._parameters["labels"],
+            min_fc=self._parameters["min_fc"],
+            alpha=self._parameters["alpha"],
+            draw_line=self._parameters["draw_line"],
+            perm=self._parameters["perm"],
+            fdr=self._parameters["fdr"],
+            color_list=self._parameters["color_list"],
+        )
+        # TODO currently there's no other way to obtain both the plot and the underlying data.
+        # Should be refactored such that the interface provided by DataSet.plot_volcano() is used.
+        # One option could be to always return the whole analysis object.
+
+        return volcano_plot.plot, volcano_plot
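(Aside: one hypothetical shape for the "whole analysis object" option mentioned in the TODO above — the `AnalysisResult` name and its fields are illustrative, not existing API.)

```python
from dataclasses import dataclass, field
from typing import Any, Dict

import pandas as pd


@dataclass(frozen=True)
class AnalysisResult:
    """Bundle a rendered figure with the data and parameters behind it."""

    figure: Any  # e.g. a plotly Figure
    data: pd.DataFrame  # the underlying analysis table
    parameters: Dict[str, Any] = field(default_factory=dict)  # run parameters
```

Returning one such object everywhere would make the two-element tuples below unnecessary and keep plot and data in sync.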
+
+
+class ClustermapAnalysis(AbstractAnalysis):
+    """Widget for Clustermap analysis."""
+
+    _works_with_nans = False
+
+    def _do_analysis(self):
+        """Draw Clustermap using the Clustermap class."""
+        clustermap = self._dataset.plot_clustermap()
+        return clustermap, None
+
+
+class DendrogramAnalysis(AbstractAnalysis):
+    """Widget for Dendrogram analysis."""
+
+    _works_with_nans = False
+
+    def _do_analysis(self):
+        """Draw the dendrogram."""
+        dendrogram = self._dataset.plot_dendrogram()
+        return dendrogram, None
+
+
+class DifferentialExpressionAnalysis(AbstractGroupCompareAnalysis):
+    """Widget for differential expression analysis."""
+
+    def show_widget(self):
+        """Show the widget and gather parameters."""
+
+        method = st.selectbox(
+            "Differential Analysis using:",
+            options=["ttest", "wald"],
+        )
+
+        if method == "wald":
+            self._works_with_nans = False
+
+        super().show_widget()
+
+        self._parameters.update({"method": method})
+
+    def _do_analysis(self):
+        """Perform differential expression analysis."""
+        diff_exp_analysis = self._dataset.diff_expression_analysis(
+            method=self._parameters["method"],
+            group1=self._parameters["group1"],
+            group2=self._parameters["group2"],
+            column=self._parameters["column"],
+        )
+        return diff_exp_analysis, None
+
+
+class TukeyTestAnalysis(AbstractAnalysis):
+    """Widget for Tukey-Test analysis."""
+
+    def show_widget(self):
+        """Show the widget and gather parameters."""
+
+        protein_id = st.selectbox(
+            "ProteinID/ProteinGroup",
+            options=self._dataset.mat.columns.to_list(),
+        )
+        group = st.selectbox(
+            "A metadata variable to calculate a pairwise tukey test",
+            options=self._dataset.metadata.columns.to_list(),
+        )
+        self._parameters.update({"protein_id": protein_id, "group": group})
+
+    def _do_analysis(self):
+        """Perform Tukey-test analysis."""
+        tukey_test_analysis = self._dataset.tukey_test(
+            protein_id=self._parameters["protein_id"],
+            group=self._parameters["group"],
+        )
+        return tukey_test_analysis, None
+
+
+class AnovaAnalysis(AbstractGroupCompareAnalysis):
+    """Widget for ANOVA analysis."""
+
+    def show_widget(self):
+        """Show the widget and gather parameters."""
+
+        column = st.selectbox(
+            "A variable from the metadata to calculate ANOVA",
+            options=self._dataset.metadata.columns.to_list(),
+        )
+        protein_ids = st.selectbox(
+            "All ProteinIDs/or specific ProteinID to perform ANOVA",
+            options=["all"] + self._dataset.mat.columns.to_list(),
+        )
+
+        tukey = st.checkbox("Follow-up Tukey")
+
+        self._parameters.update(
+            {"column": column, "protein_ids": protein_ids, "tukey": tukey}
+        )
+
+    def _do_analysis(self):
+        """Perform ANOVA analysis."""
+        anova_analysis = self._dataset.anova(
+            column=self._parameters["column"],
+            protein_ids=self._parameters["protein_ids"],
+            tukey=self._parameters["tukey"],
+        )
+        return anova_analysis, None
+
+
+class AncovaAnalysis(AbstractAnalysis):
+    """Widget for ANCOVA analysis."""
+
+    def show_widget(self):
+        """Show the widget and gather parameters."""
+
+        protein_id = st.selectbox(
+            "ProteinID/ProteinGroup",
+            options=self._dataset.mat.columns.to_list(),
+        )
+        covar = st.selectbox(
+            "Name(s) of column(s) in metadata with the covariate.",
+            options=self._dataset.metadata.columns.to_list(),
+        )  # TODO: why plural if only one can be selected?
+ between = st.selectbox( + "Name of the column in the metadata with the between factor.", + options=self._dataset.metadata.columns.to_list(), + ) + + self._parameters.update( + {"protein_id": protein_id, "covar": covar, "between": between} + ) + + def _do_analysis(self): + """Perform ANCOVA analysis.""" + ancova_analysis = self._dataset.ancova( + protein_id=self._parameters["protein_id"], + covar=self._parameters["covar"], + between=self._parameters["between"], + ) + return ancova_analysis, None + + +ANALYSIS_OPTIONS = { + PlottingOptions.VOLCANO_PLOT: VolcanoPlotAnalysis, + PlottingOptions.PCA_PLOT: PCAPlotAnalysis, + PlottingOptions.UMAP_PLOT: UMAPPlotAnalysis, + PlottingOptions.TSNE_PLOT: TSNEPlotAnalysis, + PlottingOptions.SAMPLE_DISTRIBUTION_PLOT: SampleDistributionPlot, + PlottingOptions.INTENSITY_PLOT: IntensityPlot, + PlottingOptions.CLUSTERMAP: ClustermapAnalysis, + PlottingOptions.DENDROGRAM: DendrogramAnalysis, + StatisticOptions.DIFFERENTIAL_EXPRESSION: DifferentialExpressionAnalysis, + StatisticOptions.TUKEY_TEST: TukeyTestAnalysis, + StatisticOptions.ANOVA: AnovaAnalysis, + StatisticOptions.ANCOVA: AncovaAnalysis, +} diff --git a/alphastats/gui/utils/analysis_helper.py b/alphastats/gui/utils/analysis_helper.py index 78ca1967..46c5bcb6 100644 --- a/alphastats/gui/utils/analysis_helper.py +++ b/alphastats/gui/utils/analysis_helper.py @@ -1,400 +1,198 @@ import io +from typing import Any, Callable, Dict, Optional, Tuple, Union import pandas as pd import streamlit as st -from alphastats.gui.utils.ui_helper import StateKeys, convert_df -from alphastats.plots.VolcanoPlot import VolcanoPlot +from alphastats.gui.utils.analysis import ( + ANALYSIS_OPTIONS, + PlottingOptions, + StatisticOptions, +) +from alphastats.gui.utils.ui_helper import ( + StateKeys, + show_button_download_df, +) +from alphastats.plots.PlotUtils import PlotlyObject + + +@st.fragment +def display_analysis_result_with_buttons( + df: pd.DataFrame, + analysis_method: str, + parameters: Optional[Dict], + show_save_button=True, + name: str = None, +) -> None: + """A fragment to display a statistical analysis and download options.""" + + if analysis_method in PlottingOptions.get_values(): + display_function = display_figure + download_function = _show_buttons_download_figure + elif analysis_method in StatisticOptions.get_values(): + display_function = _display_df + download_function = show_button_download_df + else: + raise ValueError(f"Analysis method {analysis_method} not found.") + + _display( + df, + analysis_method=analysis_method, + parameters=parameters, + show_save_button=show_save_button, + name=name, + display_function=display_function, + download_function=download_function, + ) -def check_if_options_are_loaded(f): - # decorator to check for missing values - # TODO remove this - def inner(*args, **kwargs): - if hasattr(st.session_state, StateKeys.PLOTTING_OPTIONS) is False: - load_options() +def _display( + analysis_result: Union[PlotlyObject, pd.DataFrame], + *, + analysis_method: str, + display_function: Callable, + download_function: Callable, + parameters: Dict, + name: str, + show_save_button: bool, +) -> None: + """Display analysis results and download options.""" + display_function(analysis_result) + + c1, c2, c3 = st.columns([1, 1, 1]) + + if name is None: + name = analysis_method + + name_pretty = name.replace(" ", "_").lower() + with c1: + if show_save_button and st.button("Save to results page.."): + _save_analysis_to_session_state( + analysis_result, analysis_method, parameters + ) + 
diff --git a/alphastats/gui/utils/analysis_helper.py b/alphastats/gui/utils/analysis_helper.py
index 78ca1967..46c5bcb6 100644
--- a/alphastats/gui/utils/analysis_helper.py
+++ b/alphastats/gui/utils/analysis_helper.py
@@ -1,400 +1,198 @@
 import io
+from typing import Any, Callable, Dict, Optional, Tuple, Union

 import pandas as pd
 import streamlit as st

-from alphastats.gui.utils.ui_helper import StateKeys, convert_df
-from alphastats.plots.VolcanoPlot import VolcanoPlot
+from alphastats.gui.utils.analysis import (
+    ANALYSIS_OPTIONS,
+    PlottingOptions,
+    StatisticOptions,
+)
+from alphastats.gui.utils.ui_helper import (
+    StateKeys,
+    show_button_download_df,
+)
+from alphastats.plots.PlotUtils import PlotlyObject
+
+
+@st.fragment
+def display_analysis_result_with_buttons(
+    df: pd.DataFrame,
+    analysis_method: str,
+    parameters: Optional[Dict],
+    show_save_button: bool = True,
+    name: Optional[str] = None,
+) -> None:
+    """A fragment to display an analysis result and download options."""
+
+    if analysis_method in PlottingOptions.get_values():
+        display_function = display_figure
+        download_function = _show_buttons_download_figure
+    elif analysis_method in StatisticOptions.get_values():
+        display_function = _display_df
+        download_function = show_button_download_df
+    else:
+        raise ValueError(f"Analysis method {analysis_method} not found.")
+
+    _display(
+        df,
+        analysis_method=analysis_method,
+        parameters=parameters,
+        show_save_button=show_save_button,
+        name=name,
+        display_function=display_function,
+        download_function=download_function,
+    )


-def check_if_options_are_loaded(f):
-    # decorator to check for missing values
-    # TODO remove this
-    def inner(*args, **kwargs):
-        if hasattr(st.session_state, StateKeys.PLOTTING_OPTIONS) is False:
-            load_options()
+def _display(
+    analysis_result: Union[PlotlyObject, pd.DataFrame],
+    *,
+    analysis_method: str,
+    display_function: Callable,
+    download_function: Callable,
+    parameters: Dict,
+    name: Optional[str],
+    show_save_button: bool,
+) -> None:
+    """Display analysis results and download options."""
+    display_function(analysis_result)
+
+    c1, c2, c3 = st.columns([1, 1, 1])
+
+    if name is None:
+        name = analysis_method
+
+    name_pretty = name.replace(" ", "_").lower()
+    with c1:
+        if show_save_button and st.button("Save to results page..."):
+            _save_analysis_to_session_state(
+                analysis_result, analysis_method, parameters
+            )
+            st.toast("Saved to results page!", icon="✅")

-        return f(*args, **kwargs)
+    with c2:
+        download_function(
+            analysis_result,
+            name_pretty,
+        )

-    return inner
+    with c3:
+        _show_button_download_analysis_and_preprocessing_info(
+            analysis_method, analysis_result, parameters, name_pretty
+        )


-def display_figure(plot):
-    """
-    display plotly or seaborn figure
-    """
+def display_figure(plot: PlotlyObject) -> None:
+    """Display plotly or seaborn figure."""
     try:
-        st.plotly_chart(plot.update_layout(plot_bgcolor="white"))
-    except:
+        st.plotly_chart(plot)
+    except Exception:
         st.pyplot(plot)


-def save_plot_to_session_state(plot, method):
-    """
-    save plot with method to session state to retrieve old results
-    """
-    st.session_state[StateKeys.PLOT_LIST] += [(method, plot)]
-
-
-def display_df(df):
-    mask = df.applymap(type) != bool
-    d = {True: "TRUE", False: "FALSE"}
-    df = df.where(mask, df.replace(d))
-    st.dataframe(df)
+def _show_buttons_download_figure(analysis_result: PlotlyObject, name: str) -> None:
+    """Show buttons to download figure as .pdf or .svg."""
+    # TODO We have to check for all scatter plotly figures which renderer they use.
+    #  Default is webgl, which is good for browser performance, but looks horrendous in the svg download;
+    #  re-rendering with svg as renderer could be a method of PlotlyObject to invoke prior to saving as svg.
+    _show_button_download_figure(analysis_result, name, "pdf")
+    _show_button_download_figure(analysis_result, name, "svg")


-def download_figure(obj, format, plotting_library="plotly"):
-    """
-    download plotly figure
-    """
-
-    plot = obj[1]
-    filename = obj[0] + "." + format
+def _show_button_download_figure(
+    plot: PlotlyObject,
+    file_name: str,
+    file_format: str,
+) -> None:
+    """Show a button to download a figure."""

     buffer = io.BytesIO()

-    if plotting_library == "plotly":
-        # Save the figure as a pdf to the buffer
-        plot.write_image(file=buffer, format=format)
+    try:  # plotly
+        plot.write_image(file=buffer, format=file_format)
+    except AttributeError:  # seaborn
+        plot.savefig(buffer, format=file_format)

-    else:
-        plot.savefig(buffer, format=format)
-
-    st.download_button(label="Download as " + format, data=buffer, file_name=filename)
-
-
-def download_preprocessing_info(plot):
-    preprocesing_dict = plot[1].preprocessing
-    df = pd.DataFrame(preprocesing_dict.items())
-    filename = "plot" + plot[0] + "preprocessing_info.csv"
-    csv = convert_df(df)
     st.download_button(
-        "Download DataSet Info as .csv",
-        csv,
-        filename,
-        "text/csv",
-        key="preprocessing",
+        label="Download as ." + file_format,
+        data=buffer,
+        file_name=file_name + "." + file_format,
+    )
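
One possible shape for the renderer TODO in `_show_buttons_download_figure` above: rebuild WebGL scatter traces as SVG traces before a vector export. This is a sketch under the assumption that the figures are `plotly.graph_objects` figures; it is not part of the diff:

```python
# Sketch only: swap WebGL scatter traces (Scattergl) for SVG ones (Scatter)
# before writing a .svg, so the export is not rasterized.
import plotly.graph_objects as go


def _with_svg_scatter_traces(fig: go.Figure) -> go.Figure:
    """Return a new figure in which every Scattergl trace is a Scatter trace."""
    traces = [
        go.Scatter({k: v for k, v in t.to_plotly_json().items() if k != "type"})
        if t.type == "scattergl"
        else t
        for t in fig.data
    ]
    # Build a fresh figure: plotly does not allow assigning new traces to fig.data.
    return go.Figure(data=traces, layout=fig.layout)
```
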


-def get_unique_values_from_column(column):
-    unique_values = (
-        st.session_state[StateKeys.DATASET].metadata[column].unique().tolist()
-    )
-    return unique_values
-
-
-def st_general(method_dict):
-    chosen_parameter_dict = {}
-
-    if "settings" in list(method_dict.keys()):
-        settings_dict = method_dict.get("settings")
-
-        for parameter in settings_dict:
-            parameter_dict = settings_dict[parameter]
-
-            if "options" in parameter_dict:
-                chosen_parameter = st.selectbox(
-                    parameter_dict.get("label"), options=parameter_dict.get("options")
-                )
-            else:
-                chosen_parameter = st.checkbox(parameter_dict.get("label"))
-
-            chosen_parameter_dict[parameter] = chosen_parameter
-
-    submitted = st.button("Submit")
-
-    if submitted:
-        with st.spinner("Calculating..."):
-            return method_dict["function"](**chosen_parameter_dict)
+# TODO: use pandas stylers, rather than changing the data
+def _display_df(df: pd.DataFrame) -> None:
+    """Display a dataframe."""
+    mask = df.applymap(type) != bool  # noqa: E721
+    df = df.where(mask, df.replace({True: "TRUE", False: "FALSE"}))
+    st.dataframe(df)
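
For the styler TODO in `_display_df` above, the formatting could happen at render time instead of mutating the frame. A sketch, assuming the boolean columns carry a clean `bool` dtype and that the installed Streamlit version accepts a pandas `Styler` in `st.dataframe` (recent versions do):

```python
# Sketch only: format booleans for display without changing the data.
def _display_df_styled(df: pd.DataFrame) -> None:
    """Display a dataframe, upper-casing booleans at render time only."""
    bool_cols = df.select_dtypes(include="bool").columns
    styler = df.style.format(
        lambda v: "TRUE" if v else "FALSE", subset=list(bool_cols)
    )
    st.dataframe(styler)
```
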


-# @st.cache_data # TODO check if caching is sensible here and if so, reimplement with dataset-hash
-def gui_volcano_plot_differential_expression_analysis(
-    chosen_parameter_dict,
+def _show_button_download_analysis_and_preprocessing_info(
+    method: str,
+    analysis_result: Union[PlotlyObject, pd.DataFrame],
+    parameters: Dict,
+    name: str,
 ):
-    """
-    initalize volcano plot object with differential expression analysis results
-    """
-    dataset = st.session_state[StateKeys.DATASET]
-
-    # TODO this is just a quickfix, a simple interface needs to be provided by DataSet
-    volcano_plot = VolcanoPlot(
-        mat=dataset.mat,
-        rawinput=dataset.rawinput,
-        metadata=dataset.metadata,
-        sample=dataset.sample,
-        index_column=dataset.index_column,
-        gene_names=dataset._gene_names,
-        preprocessing_info=dataset.preprocessing_info,
-        **chosen_parameter_dict,
-        plot=False,
-    )
-    volcano_plot._perform_differential_expression_analysis()
-    volcano_plot._add_hover_data_columns()
-    return volcano_plot
-
-
-def gui_volcano_plot():
-    """
-    Draw Volcano Plot using the VolcanoPlot class
-    """
-    chosen_parameter_dict = helper_compare_two_groups()
-    method = st.selectbox(
-        "Differential Analysis using:",
-        options=["ttest", "anova", "wald", "sam", "paired-ttest", "welch-ttest"],
-    )
-    chosen_parameter_dict.update({"method": method})
-
-    # TODO streamlit doesnt allow nested columns check for updates
-
-    labels = st.checkbox("Add label")
-
-    draw_line = st.checkbox("Draw line")
-
-    alpha = st.number_input(
-        label="alpha", min_value=0.001, max_value=0.050, value=0.050
-    )
-
-    min_fc = st.select_slider("Foldchange cutoff", range(0, 3), value=1)
-
-    plotting_parameter_dict = {
-        "labels": labels,
-        "draw_line": draw_line,
-        "alpha": alpha,
-        "min_fc": min_fc,
+    """Download analysis info (= analysis and preprocessing parameters) as .csv."""
+    parameters_pretty = {
+        f"analysis_parameter__{k}": "None" if v is None else v
+        for k, v in parameters.items()
     }
-    if method == "sam":
-        perm = st.number_input(
-            label="Number of Permutations", min_value=1, max_value=1000, value=10
-        )
-        fdr = st.number_input(
-            label="FDR cut off", min_value=0.005, max_value=0.1, value=0.050
-        )
-        chosen_parameter_dict.update({"perm": perm, "fdr": fdr})
-
-    submitted = st.button("Submit")
-
-    if submitted:
-        # TODO this seems not be covered by unit test
-        volcano_plot = gui_volcano_plot_differential_expression_analysis(
-            chosen_parameter_dict
-        )
-        volcano_plot._update(plotting_parameter_dict)
-        volcano_plot._annotate_result_df()
-        volcano_plot._plot()
-        return volcano_plot.plot
-
-
-def get_analysis_options_from_dict(method, options_dict):
-    """
-    extract plotting options from dict amd display as selectbox or
-    give selceted options to plotting function
-    """
-
-    method_dict = options_dict.get(method)
-
-    if method == "t-SNE Plot":
-        return st_tsne_options(method_dict)
-
-    elif method == "Differential Expression Analysis - T-test":
-        return st_calculate_ttest(method=method, options_dict=options_dict)
-
-    elif method == "Differential Expression Analysis - Wald-test":
-        return st_calculate_waldtest(method=method, options_dict=options_dict)
-
-    elif method == "Volcano Plot":
-        return gui_volcano_plot()
-
-    elif method == "PCA Plot":
-        return st_plot_pca(method_dict)
-
-    elif method == "UMAP Plot":
-        return st_plot_umap(method_dict)
-
-    elif "settings" not in method_dict:
-        if st.session_state[StateKeys.DATASET].mat.isna().values.any() == True:
-            st.error(
-                "Data contains missing values impute your data before plotting (Preprocessing - Imputation)."
-            )
-            return
-
-        chosen_parameter_dict = {}
-        return method_dict["function"](**chosen_parameter_dict)
-
+    if method in PlottingOptions.get_values():
+        dict_to_save = {
+            **analysis_result.preprocessing,
+            **parameters_pretty,
+        }  # TODO why is the preprocessing info saved in the plots?
     else:
-        return st_general(method_dict=method_dict)
-
+        dict_to_save = parameters_pretty

-def st_plot_pca(method_dict):
-    chosen_parameter_dict = helper_plot_dimensionality_reduction(
-        method_dict=method_dict
+    show_button_download_df(
+        pd.DataFrame(dict_to_save.items()),
+        file_name=f"analysis_info__{name}",
+        label="Download analysis info as .csv",
     )

-    submitted = st.button("Submit")
-
-    if submitted:
-        with st.spinner("Calculating..."):
-            return method_dict["function"](**chosen_parameter_dict)
-
-
-def st_plot_umap(method_dict):
-    chosen_parameter_dict = helper_plot_dimensionality_reduction(
-        method_dict=method_dict
-    )
-
-    submitted = st.button("Submit")
-
-    if submitted:
-        with st.spinner("Calculating..."):
-            return method_dict["function"](**chosen_parameter_dict)
-
-
-def st_calculate_ttest(method, options_dict):
-    """
-    perform ttest in streamlit
-    """
-    chosen_parameter_dict = helper_compare_two_groups()
-    chosen_parameter_dict.update({"method": "ttest"})
-
-    submitted = st.button("Submit")
-
-    if submitted:
-        with st.spinner("Calculating..."):
-            return options_dict.get(method)["function"](**chosen_parameter_dict)
-
-
-def st_calculate_waldtest(method, options_dict):
-    chosen_parameter_dict = helper_compare_two_groups()
-    chosen_parameter_dict.update({"method": "wald"})
-
-    submitted = st.button("Submit")
-
-    if submitted:
-        with st.spinner("Calculating..."):
-            return options_dict.get(method)["function"](**chosen_parameter_dict)
-

-def helper_plot_dimensionality_reduction(method_dict):
-    group = st.selectbox(
-        method_dict["settings"]["group"].get("label"),
-        options=method_dict["settings"]["group"].get("options"),
-    )
+def _save_analysis_to_session_state(
+    analysis_results: Union[PlotlyObject, pd.DataFrame],
+    method: str,
+    parameters: Dict,
+):
+    """Save analysis with method and parameters to session state to show on results page."""
+    st.session_state[StateKeys.ANALYSIS_LIST] += [
+        (
+            analysis_results,
+            method,
+            parameters,
+        )
+    ]

-    circle = False
-    if group is not None:
-        circle = st.checkbox("circle")

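
The `StateKeys.ANALYSIS_LIST` entries saved by `_save_analysis_to_session_state` above are `(result, method, parameters)` triples, so a results page can re-render them with the same display helper. A sketch of such a consumer, hypothetical apart from the names defined in this diff:

```python
# Sketch only: re-render saved analyses on a results page.
for result, method, parameters in st.session_state.get(StateKeys.ANALYSIS_LIST, []):
    display_analysis_result_with_buttons(
        result,
        analysis_method=method,
        parameters=parameters,
        show_save_button=False,  # already saved
    )
```
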
+def gather_parameters_and_do_analysis(
+    analysis_method: str,
+) -> Tuple[
+    Optional[Union[PlotlyObject, pd.DataFrame]], Optional[Any], Optional[Dict[str, Any]]
+]:
+    """Gather parameters via the analysis widgets and perform the analysis.

-    chosen_parameter_dict = {
-        "circle": circle,
-        "group": group,
-    }
-    return chosen_parameter_dict
+    Returns a tuple (figure, analysis_object, parameters), where figure is the plot,
+    analysis_object is the underlying object, and parameters is a dictionary of the parameters used.

-
-def helper_compare_two_groups():
-    """
-    Helper function to compare two groups for example
-    Volcano Plot, Differetial Expression Analysis and t-test
-    selectbox based on selected column
+    Currently, analysis_object is only not-None for Volcano Plot.
     """
-
-    chosen_parameter_dict = {}
-    default_option = "