Skip to content

Commit

Permalink
try faster tests
Browse files Browse the repository at this point in the history
  • Loading branch information
euhoro committed Jun 18, 2024
1 parent 6f260c5 commit 7fe4105
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 177 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/basic.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ jobs:
docker run --name redis -d -p 6379:6379 redis
- name: Start Uvicorn server
run: |
nohup SETTINGS_MODE=redis uvicorn app.main:app --host 127.0.0.1 --port 8000 &
nohup SETTINGS_MODE=text uvicorn app.main:app --host 127.0.0.1 --port 8000 &
sleep 5 # Wait for the server to start
- name: Run tests with pytest
run: |
SETTINGS_MODE=redis pytest tests
SETTINGS_MODE=text pytest tests
- name: Stop Uvicorn server
if: always()
run: |
Expand Down
176 changes: 1 addition & 175 deletions atm_service_redis.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,6 @@
import threading
import time
import uuid
from abc import ABC, abstractmethod

import redis
import json
from pydantic import BaseModel, Field
from typing import Dict
import random
from concurrent.futures import ThreadPoolExecutor, as_completed

# from atm_repository_sqllite import SQLiteInventoryService
from atm_service_json_file import JSONFileInventoryService

#
# class Inventory(BaseModel):
# BILL: Dict[float, int] = Field(default_factory=dict)
# COIN: Dict[float, int] = Field(default_factory=dict)
#
#
# class InventoryService(ABC):
# @abstractmethod
# def read_inventory(self) -> Inventory:
# pass
#
# @abstractmethod
# def write_inventory(self, inventory: Inventory):
# pass
#
# @abstractmethod
# def restart(self):
# pass
#
# @abstractmethod
# def acquire_lock(self):
# pass
#
# @abstractmethod
# def release_lock(self):
# pass
from common import InventoryService, Inventory


Expand Down Expand Up @@ -76,140 +38,4 @@ def acquire_lock(self):

def release_lock(self):
if self.client.get(self.lock_name) == self.lock_value:
self.client.delete(self.lock_name)


class ATMService:
def __init__(self, inventory_service):
self.inventory_service = inventory_service

def perform_transaction(self, action, item_type, denomination, quantity, retries=3):
while retries > 0:
try:
self.inventory_service.acquire_lock()
try:
inventory = self.inventory_service.read_inventory()

denomination = float(denomination)
quantity = int(quantity)

if action == "put":
if item_type == "BILL":
if denomination in inventory.BILL:
inventory.BILL[denomination] += quantity
else:
inventory.BILL[denomination] = quantity
elif item_type == "COIN":
if denomination in inventory.COIN:
inventory.COIN[denomination] += quantity
else:
inventory.COIN[denomination] = quantity
self.inventory_service.write_inventory(inventory)
print(f"Put {quantity} of {denomination} {item_type.lower()}. New state: {inventory}")
return True

elif action == "retrieve":
if item_type == "BILL":
if denomination in inventory.BILL and inventory.BILL[denomination] >= quantity:
inventory.BILL[denomination] -= quantity
else:
return False
elif item_type == "COIN":
if denomination in inventory.COIN and inventory.COIN[denomination] >= quantity:
inventory.COIN[denomination] -= quantity
else:
return False
self.inventory_service.write_inventory(inventory)
print(f"Retrieved {quantity} of {denomination} {item_type.lower()}. New state: {inventory}")
return True

return False
finally:
self.inventory_service.release_lock()
except Exception as e:
print(f"Error performing transaction: {e}. Retrying...")
retries -= 1
time.sleep(0.1)
return False

def get_total(self):
self.inventory_service.acquire_lock()
try:
inventory = self.inventory_service.read_inventory()
total = 0
for denomination, quantity in inventory.BILL.items():
total += float(denomination) * quantity
for denomination, quantity in inventory.COIN.items():
total += float(denomination) * quantity
return total
finally:
self.inventory_service.release_lock()


# Stress Testing function
def stress_test(timeout=10):
redis_client = redis.StrictRedis(host='localhost', port=6379, db=0)
inventory_service = RedisInventoryService(redis_client)

# inventory_service = SQLiteInventoryService()
# inventory_service = JSONFileInventoryService()
atm_service = ATMService(inventory_service)

# Restart to ensure a clean state
inventory_service.restart()

transaction_log = []
num_threads = 20 # Set number of threads for stress testing
actions = ["put", "retrieve"]
item_types = ["BILL", "COIN"]
denominations = {
"BILL": [200.0, 100.0, 20.0],
"COIN": [10.0, 5.0, 1.0, 0.1, 0.01]
}

def perform_random_transaction():
action = random.choice(actions)
item_type = random.choice(item_types)
denomination = random.choice(denominations[item_type])
quantity = random.randint(1, 10)

success = atm_service.perform_transaction(action, item_type, denomination, quantity)
if success:
transaction_log.append((action, item_type, denomination, quantity))
print(f"Transaction: {action.capitalize()} {quantity} of {denomination} {item_type.lower()}.")

# Initial total
initial_total = atm_service.get_total()
print(f"Initial Total: {initial_total}")

# Create threads for random transactions
threads = []
for i in range(num_threads):
t = threading.Thread(target=perform_random_transaction)
threads.append(t)
t.start()

# Wait for all threads to complete
for t in threads:
t.join(timeout)

# Final total
final_total = atm_service.get_total()
print(f"Final Total: {final_total}")

# Calculate the expected total from transaction log
expected_total = initial_total
for action, item_type, denomination, quantity in transaction_log:
if action == "put":
expected_total += float(denomination) * quantity
elif action == "retrieve":
expected_total -= float(denomination) * quantity

print(f"Expected Total from Transactions: {expected_total}")

# Assert the final total matches the expected total
assert final_total == expected_total, "The final total does not match the expected total from transactions"


# Run the stress test with a timeout of 10 seconds
#stress_test(timeout=20)
self.client.delete(self.lock_name)
78 changes: 78 additions & 0 deletions tests/test_stress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
#
# # Stress Testing function
# import threading
# from random import random
#
# import redis
#
# from atm_service import ATMService
# from atm_service_redis import RedisInventoryService
# from common import COIN, BILL
#
#
# def stress_check(timeout=10):
# redis_client = redis.StrictRedis(host='localhost', port=6379, db=0)
# inventory_service = RedisInventoryService(redis_client)
#
# # inventory_service = SQLiteInventoryService()
# # inventory_service = JSONFileInventoryService()
# atm_service = ATMService(inventory_service)
#
# # Restart to ensure a clean state
# inventory_service.restart()
#
# transaction_log = []
# num_threads = 20 # Set number of threads for stress testing
# actions = ["put", "retrieve"]
# item_types = [BILL, COIN]
# denominations = {
# "BILL": [200.0, 100.0, 20.0],
# "COIN": [10.0, 5.0, 1.0, 0.1, 0.01]
# }
#
# def perform_random_transaction():
# action = random.choice(actions)
# item_type = random.choice(item_types)
# denomination = random.choice(denominations[item_type])
# quantity = random.randint(1, 10)
#
# success = atm_service.perform_transaction(action, item_type, denomination, quantity)
# if success:
# transaction_log.append((action, item_type, denomination, quantity))
# print(f"Transaction: {action.capitalize()} {quantity} of {denomination} {item_type.lower()}.")
#
# # Initial total
# initial_total = atm_service.get_total()
# print(f"Initial Total: {initial_total}")
#
# # Create threads for random transactions
# threads = []
# for i in range(num_threads):
# t = threading.Thread(target=perform_random_transaction)
# threads.append(t)
# t.start()
#
# # Wait for all threads to complete
# for t in threads:
# t.join(timeout)
#
# # Final total
# final_total = atm_service.get_total()
# print(f"Final Total: {final_total}")
#
# # Calculate the expected total from transaction log
# expected_total = initial_total
# for action, item_type, denomination, quantity in transaction_log:
# if action == "put":
# expected_total += float(denomination) * quantity
# elif action == "retrieve":
# expected_total -= float(denomination) * quantity
#
# print(f"Expected Total from Transactions: {expected_total}")
#
# # Assert the final total matches the expected total
# assert final_total == expected_total, "The final total does not match the expected total from transactions"
#
#
# def test_stress():
# stress_check()

0 comments on commit 7fe4105

Please sign in to comment.