Skip to content

Commit

Permalink
Merge pull request #691 from Geeks-Sid/fix-collectstats
Browse files Browse the repository at this point in the history
Fixed plotting function for final stats
  • Loading branch information
sarthakpati authored Jul 14, 2023
2 parents 687710b + 0918202 commit 8a8e437
Showing 1 changed file with 133 additions and 208 deletions.
341 changes: 133 additions & 208 deletions gandlf_collectStats
Original file line number Diff line number Diff line change
@@ -1,20 +1,141 @@
#!usr/bin/env python
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import argparse
import ast
from pathlib import Path
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from pathlib import Path
from io import StringIO

from GANDLF.cli import copyrightMessage
from GANDLF.utils.plot_utils import plot_all

import os
import seaborn as sns
import matplotlib.pyplot as plt
from pathlib import Path

from GANDLF.cli import copyrightMessage

def plot_all(df_training, df_validation, df_testing, output_plot_dir):
"""
Plots training, validation, and testing data for loss and other metrics.
Args:
df_training (pd.DataFrame): DataFrame containing training data.
df_validation (pd.DataFrame): DataFrame containing validation data.
df_testing (pd.DataFrame): DataFrame containing testing data.
output_plot_dir (str): Directory to save the plots.
Returns:
tuple: Tuple containing the modified training, validation, and testing DataFrames.
"""
# Drop any columns that might have "_" in the values of their rows
banned_cols = [
col
for col in df_training.columns
if any("_" in str(val) for val in df_training[col].values)
]

# Determine metrics from the column names by removing the "train_" prefix
metrics = [
col.replace("train_", "")
for col in df_training.columns
if "train_" in col and col not in banned_cols
]

# Split the values of the banned columns into multiple columns
# for df in [df_training, df_validation, df_testing]:
# for col in banned_cols:
# if df[col].dtype == "object":
# split_cols = (
# df[col]
# .str.split("_", expand=True)
# .apply(pd.to_numeric, errors="coerce")
# )
# split_cols.columns = [f"{col}_{i}" for i in range(split_cols.shape[1])]
# df.drop(columns=col, inplace=True)
# df = pd.concat([df, split_cols], axis=1)

# Check if any of the metrics is present in the column names of the dataframe
assert any(
any(metric in col for col in df_training.columns) for metric in metrics
), "None of the specified metrics is in the dataframe."

required_cols = ["epoch_no", "train_loss"]

# Check if the required columns are in the dataframe
assert all(
col in df_training.columns for col in required_cols
), "Not all required columns are in the dataframe."

epochs = len(df_training)

# Plot for loss
plt.figure(figsize=(12, 6))
if "train_loss" in df_training.columns:
sns.lineplot(data=df_training, x="epoch_no", y="train_loss", label="Training")

if "valid_loss" in df_validation.columns:
sns.lineplot(
data=df_validation, x="epoch_no", y="valid_loss", label="Validation"
)

if df_testing is not None and "test_loss" in df_testing.columns:
sns.lineplot(data=df_testing, x="epoch_no", y="test_loss", label="Testing")

plt.xlim(0, epochs - 1)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Loss Plot")
plt.legend()
Path(output_plot_dir).mkdir(parents=True, exist_ok=True)
plt.savefig(os.path.join(output_plot_dir, "loss_plot.png"), dpi=300)
plt.close()

# Plot for other metrics
for metric in metrics:
metric_cols = [col for col in df_training.columns if metric in col]
for metric_col in metric_cols:
plt.figure(figsize=(12, 6))
if metric_col in df_training.columns:
sns.lineplot(
data=df_training,
x="epoch_no",
y=metric_col,
label=f"Training {metric_col}",
)
if metric_col.replace("train", "valid") in df_validation.columns:
sns.lineplot(
data=df_validation,
x="epoch_no",
y=metric_col.replace("train", "valid"),
label=f"Validation {metric_col}",
)
if (
df_testing is not None
and metric_col.replace("train", "test") in df_testing.columns
):
sns.lineplot(
data=df_testing,
x="epoch_no",
y=metric_col.replace("train", "test"),
label=f"Testing {metric_col}",
)
plt.xlim(0, epochs - 1)
plt.xlabel("Epoch")
plt.ylabel(metric.capitalize())
plt.title(f"{metric.capitalize()} Plot")
plt.legend()
plt.savefig(os.path.join(output_plot_dir, f"{metric}_plot.png"), dpi=300)
plt.close()

print("Plots saved successfully.")
return df_training, df_validation, df_testing


def main():
if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog="GANDLF_CollectStats",
formatter_class=argparse.RawTextHelpFormatter,
Expand All @@ -26,7 +147,7 @@ def main():
"--modeldir",
metavar="",
type=str,
help="Input directory which contains testing and validation models",
help="Input directory which contains testing and validation models log files",
)
parser.add_argument(
"-o",
Expand All @@ -35,14 +156,6 @@ def main():
type=str,
help="Output directory to save stats and plot",
)
parser.add_argument(
"-c",
"--combinedplots",
metavar="",
default=False,
type=ast.literal_eval,
help="Overlays training and validation plots for both accuracy and loss (classification only).",
)

args = parser.parse_args()

Expand All @@ -52,202 +165,14 @@ def main():
outputFile = os.path.join(outputDir, "data.csv") # data file name
outputPlot = os.path.join(outputDir, "plot.png") # plot file

combinedPlots = args.combinedplots

trainingLogs = os.path.join(inputDir, "logs_training.csv")
validationLogs = os.path.join(inputDir, "logs_validation.csv")
testingLogs = os.path.join(inputDir, "logs_testing.csv")

if os.path.exists(testingLogs):
testingLogsCSV = pd.read_csv(testingLogs)

# check for classification task
if len(testingLogsCSV) == 0:
print("Classification task detected, generating accuracy and loss plots.")

# check whether user wants training + validation overlaid plots
if combinedPlots:
df_training = pd.read_csv(trainingLogs)
df_validation = pd.read_csv(validationLogs)
# Read all the files
df_training = pd.read_csv(trainingLogs)
df_validation = pd.read_csv(validationLogs)
df_testing = pd.read_csv(testingLogs) if os.path.isfile(testingLogs) else None

epochs = len(df_training)

fig, axes = plt.subplots(nrows=1, ncols=2) # set plot properties
# ensure spacing between plots
plt.subplots_adjust(wspace=0.5, hspace=0.5)
# plot training accuracy data
splot = sns.lineplot(
data=df_training,
x="epoch_no",
y="train_balanced_accuracy",
ax=axes[0],
)
# plot validation accuracy data
splot = sns.lineplot(
data=df_validation,
x="epoch_no",
y="valid_balanced_accuracy",
ax=axes[0],
)
# set limits for x-axis for proper visualization
splot.set(xlim=(0, epochs - 1))
# set limits for y-axis for proper visualization
splot.set(ylim=(0, 1))
# add labels and title to plot
splot.set(xlabel="Epoch", ylabel="Accuracy", title="Accuracy Plot")
# add legend to plot
axes[0].legend(labels=["Training", "Validation"])

# plot training loss data
splot = sns.lineplot(
data=df_training, x="epoch_no", y="train_loss", ax=axes[1]
)
# plot validation loss data
splot = sns.lineplot(
data=df_validation, x="epoch_no", y="valid_loss", ax=axes[1]
)
# set limits for x-axis for proper visualization
splot.set(xlim=(0, epochs - 1))
# add labels and title to plot
splot.set(xlabel="Epoch", ylabel="Loss", title="Loss Plot")
# add legend to plot
axes[1].legend(labels=["Training", "Validation"])
# save plot
plt.savefig(outputPlot, dpi=600)

print("Plots saved successfully.")

else:
df_training = pd.read_csv(trainingLogs)
df_validation = pd.read_csv(validationLogs)

epochs = len(df_training)

# set plot properties
fig, axes = plt.subplots(nrows=2, ncols=2)

plt.subplots_adjust(wspace=0.5, hspace=0.5)
# plot the data
splot = sns.lineplot(
data=df_training,
x="epoch_no",
y="train_balanced_accuracy",
ax=axes[0, 0],
)
splot.set(xlim=(0, epochs - 1))
splot.set(ylim=(0, 1)) # set limits for y-axis for proper visualization
# set labels
splot.set(
xlabel="Epoch", ylabel="Accuracy", title="Training Accuracy Plot"
)

# plot the data
splot = sns.lineplot(
data=df_validation,
x="epoch_no",
y="valid_balanced_accuracy",
ax=axes[0, 1],
)
splot.set(xlim=(0, epochs - 1))
splot.set(ylim=(0, 1)) # set limits for y-axis for proper visualization
# set labels
splot.set(
xlabel="Epoch", ylabel="Accuracy", title="Validation Accuracy Plot"
)
# plot the data
splot = sns.lineplot(
data=df_training, x="epoch_no", y="train_loss", ax=axes[1, 0]
)
splot.set(xlim=(0, epochs - 1))
# set labels
splot.set(xlabel="Epoch", ylabel="Loss", title="Training Loss Plot")
# plot the data
splot = sns.lineplot(
data=df_validation, x="epoch_no", y="valid_loss", ax=axes[1, 1]
)
splot.set(xlim=(0, epochs - 1))
# set labels
splot.set(xlabel="Epoch", ylabel="Loss", title="Validation Loss Plot")

plt.savefig(outputPlot, dpi=600)

print("Plots saved successfully.")

else:
print("Segmentation task detected, generating dice and loss plots.")

final_stats = "Epoch,Train_Loss,Train_Dice,Val_Loss,Val_Dice,Testing_Loss,Testing_Dice\n" # the columns that need to be present in final output; epoch is always removed

# loop through output directory
for dirs in os.listdir(inputDir):
currentTestingDir = os.path.join(inputDir, dirs)
if os.path.isdir(currentTestingDir): # go in only if it is a directory
if "testing_" in dirs: # ensure it is part of the testing structure
# loop through all validation directories
for val in os.listdir(currentTestingDir):
currentValidationDir = os.path.join(currentTestingDir, val)
if os.path.isdir(currentValidationDir):
# get all files in each directory
filesInDir = os.listdir(currentValidationDir)

for i, n in enumerate(filesInDir):
# when the log has been found, collect the final numbers
if "trainingScores_log" in n:
log_file = os.path.join(currentValidationDir, n)
with open(log_file) as f:
for line in f:
pass
final_stats = final_stats + line

data_string = StringIO(final_stats)
data_full = pd.read_csv(data_string, sep=",")
del data_full["Epoch"] # no need for epoch
data_full.to_csv(outputFile, index=False) # save updated data

# perform deep copy
data_loss = data_full.copy()
data_dice = data_full.copy()
# set the datasets that need to be plotted
cols = [
"Train",
"Val",
"Testing",
]
for i in cols:
del data_dice[i + "_Loss"] # keep only dice
del data_loss[i + "_Dice"] # keep only loss
# rename the columns
data_loss.rename(columns={i + "_Loss": i}, inplace=True)
# rename the columns
data_dice.rename(columns={i + "_Dice": i}, inplace=True)
# set plot properties
fig, axes = plt.subplots(nrows=1, ncols=2, constrained_layout=True)
# plot the data
bplot = sns.boxplot(
data=data_dice, width=0.5, palette="colorblind", ax=axes[0]
)
# set limits for y-axis for proper visualization
bplot.set(ylim=(0, 1))
# set labels
bplot.set(xlabel="Dataset", ylabel="Dice", title="Dice plot")
# rotate so that everything is visible
bplot.set_xticklabels(bplot.get_xticklabels(), rotation=15, ha="right")
# plot the data
bplot = sns.boxplot(
data=data_loss, width=0.5, palette="colorblind", ax=axes[1]
)
# set limits for y-axis for proper visualization
bplot.set(ylim=(0, 1))
# set labels
bplot.set(xlabel="Dataset", ylabel="Loss", title="Loss plot")
# rotate so that everything is visible
bplot.set_xticklabels(bplot.get_xticklabels(), rotation=15, ha="right")

plt.savefig(outputPlot, dpi=600)

print("Plots saved successfully.")


# main function
if __name__ == "__main__":
main()
# Check for metrics in columns and do tight plots
plot_all(df_training, df_validation, df_testing, outputPlot)

0 comments on commit 8a8e437

Please sign in to comment.