Skip to content

Commit

Permalink
#7 update runner.py
Browse files Browse the repository at this point in the history
  • Loading branch information
khirotaka committed Feb 12, 2020
1 parent da9a0b4 commit a8f16d9
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 22 deletions.
3 changes: 1 addition & 2 deletions enchanter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,4 @@

del Dict, Union, Any, List
del io, os, time, deepcopy
del np, torch, tqdm, DataLoader, Dataset, BaseEstimator
del accuracy_score
del np, torch, tqdm, DataLoader, Dataset, base, metrics
54 changes: 35 additions & 19 deletions enchanter/engine/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@

import torch
import numpy as np
from sklearn.base import BaseEstimator
from sklearn import base, metrics
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import accuracy_score

from enchanter.engine import modules

Expand All @@ -27,7 +26,10 @@
from tqdm import tqdm


class BaseRunner(BaseEstimator):
class BaseRunner(base.BaseEstimator):
""""
BaseRunner
"""
def __init__(self, model, criterion, optimizer, optim_config, device=None, experiment=None, scheduler=None):
"""
Expand Down Expand Up @@ -120,10 +122,10 @@ def train(self, dataset, epochs, batch_size, verbose=True, shuffle=False, checkp
for epoch in epoch_bar:
self.model.train()
step_bar = tqdm(train_loader, desc="Training", leave=False) if verbose else train_loader
for i, (x, y) in enumerate(step_bar):
for i, (x_train, y_train) in enumerate(step_bar):

self.optimizer.zero_grad()
loss = self.forward(x, y)
loss = self.forward(x_train, y_train)

loss.backward()
self.optimizer.step()
Expand Down Expand Up @@ -219,7 +221,18 @@ def predict(self, x: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
out = self.model(x).cpu().numpy()
return out

def evaluate(self, x: Any, y: Any, batch_size: int = 1):
def evaluate(self, *args, **kwargs) -> Any:
"""
テストデータセットを用いてモデルの評価を行うメソッド。
このクラスを継承して新しいRunnerを作る場合必ず定義する必要がある。
Args:
*args:
**kwargs:
Returns:
"""
raise NotImplementedError

def save_checkpoint(self) -> dict:
Expand Down Expand Up @@ -282,6 +295,9 @@ def load(self, filename: str, map_location: str = "cpu") -> None:


class ClassificationRunner(BaseRunner):
"""
Runner for Classification task.
"""
def predict(self, x: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
"""
Expand All @@ -295,24 +311,24 @@ def predict(self, x: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
predict = np.argmax(out, axis=-1)
return predict

def evaluate(self, x, y=None, batch_size: int = 1, verbose: bool = True, metrics: List[callable] = None) -> Dict:
def evaluate(self, x, y=None, batch_size: int = 1, verbose: bool = True, metric_fn: List[callable] = None) -> Dict:
"""
Args:
x (Union[Union[np.ndarray, torch.Tensor], Dataset]):
y (Union[Union[np.ndarray, torch.Tensor], None]):
batch_size (int):
verbose (bool):
metrics (List[callable]): sklearn.metrics のような、 func(y_true, y_pred) の形で提供される評価関数を格納した配列
metric_fn (List[callable]): sklearn.metrics のような、 func(y_true, y_pred) の形で提供される評価関数を格納した配列
Returns:
losses (float):
accuracy (float):
"""
if metrics is None:
metrics = [accuracy_score]
if metric_fn is None:
metric_fn = [metrics.accuracy_score]
else:
metrics.append(accuracy_score)
metric_fn.append(metrics.accuracy_score)

total = 0.0
losses = 0.0
Expand All @@ -329,23 +345,23 @@ def evaluate(self, x, y=None, batch_size: int = 1, verbose: bool = True, metrics
if verbose else DataLoader(x, batch_size=batch_size, shuffle=False)

with torch.no_grad():
for x, y in loader:
total += y.shape[0]
for data, target in loader:
total += target.shape[0]

x = x.to(self.device)
y = y.to(self.device)
data = data.to(self.device)
target = target.to(self.device)

loss = self.criterion(self.model(x), y).cpu().item()
predict = self.predict(x)
loss = self.criterion(self.model(data), target).cpu().item()
predict = self.predict(data)
losses += loss

labels.append(y.cpu().numpy())
labels.append(target.cpu().numpy())
predicts.append(predict)

labels = np.hstack(labels)
predicts = np.hstack(predicts)

for func in metrics:
for func in metric_fn:
metric_values[func.__name__] = func(labels, predicts)

metric_values["loss"] = losses / total
Expand Down
2 changes: 1 addition & 1 deletion test/engine/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def main():
}
}
)
metrics = runner.evaluate(test_ds, batch_size=64, metrics=[accuracy_score])
metrics = runner.evaluate(test_ds, batch_size=64, metric_fn=[accuracy_score])

img, label = next(iter(DataLoader(test_ds, batch_size=32)))
print(runner.predict(img))
Expand Down

0 comments on commit a8f16d9

Please sign in to comment.