Skip to content

Commit

Permalink
WIP: Initial refactoring of disease specific logic from person.cu
Browse files Browse the repository at this point in the history
It's a bit grim due to templating
  • Loading branch information
ptheywood committed Dec 17, 2024
1 parent 4f8bf68 commit 5cc396b
Showing 1 changed file with 86 additions and 57 deletions.
143 changes: 86 additions & 57 deletions src/exateppabm/person.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,55 @@
namespace exateppabm {
namespace person {

/**
* Device utility function for Increasing an individuals infection counter, for more legible agent code
*
* Templated for due to the templated DeviceAPI object
*
* @todo - consider moving to the disease namespace?
*/
template<typename MsgIn, typename MsgOut>
FLAMEGPU_DEVICE_FUNCTION void incrementInfectionCounter(flamegpu::DeviceAPI<MsgIn, MsgOut>* FLAMEGPU) {
FLAMEGPU->template setVariable<std::uint32_t>(v::INFECTION_COUNT, FLAMEGPU->template getVariable<std::uint32_t>(v::INFECTION_COUNT) + 1);
}

/**
* Device utility function for when an individual is exposed, moving from susceptible to exposed
*
* Templated for due to the templated DeviceAPI object
* *
* @todo - move this to disease::SEIR
*/
template<typename MsgIn, typename MsgOut>
FLAMEGPU_DEVICE_FUNCTION void susceptibleToExposed(flamegpu::DeviceAPI<MsgIn, MsgOut>* FLAMEGPU, disease::SEIR::InfectionStateUnderlyingType& infectionStatus) {
// Generate how long the individual will be in the exposed for.
float mean = FLAMEGPU->environment.template getProperty<float>("mean_time_to_infected");
float sd = FLAMEGPU->environment.template getProperty<float>("sd_time_to_infected");
float stateDuration = (FLAMEGPU->random.template normal<float>() * sd) + mean;

// Update the referenced value containing the individuals current infections status, used to reduce branching within a device for loop.
infectionStatus = disease::SEIR::InfectionState::Exposed;
// Update individuals infection state in global agent memory
FLAMEGPU->template setVariable<disease::SEIR::InfectionStateUnderlyingType>(v::INFECTION_STATE, infectionStatus);
// Update the individual infection state duration in global agent memory
FLAMEGPU->template setVariable<float>(person::v::INFECTION_STATE_DURATION, stateDuration);

// Increment the infection counter for this individual
person::incrementInfectionCounter(FLAMEGPU);
}

/**
* Device utility function to get an individuals current infection status from global agent memory
*
* Templated for due to the templated DeviceAPI object
* *
* @todo - move this to disease::SEIR
*/
template<typename MsgIn, typename MsgOut>
FLAMEGPU_DEVICE_FUNCTION disease::SEIR::InfectionStateUnderlyingType getCurrentInfectionStatus(flamegpu::DeviceAPI<MsgIn, MsgOut>* FLAMEGPU) {
return FLAMEGPU->template getVariable<disease::SEIR::InfectionStateUnderlyingType>(v::INFECTION_STATE);
}

/**
* Agent function for person agents to emit their public information, i.e. infection status to their household
*/
Expand All @@ -29,7 +78,7 @@ FLAMEGPU_AGENT_FUNCTION(emitHouseholdStatus, flamegpu::MessageNone, flamegpu::Me
FLAMEGPU->message_out.setVariable<std::uint32_t>(v::HOUSEHOLD_IDX, householdIdx);

FLAMEGPU->message_out.setVariable<disease::SEIR::InfectionStateUnderlyingType>(v::
INFECTION_STATE, FLAMEGPU->getVariable<disease::SEIR::InfectionStateUnderlyingType>(v::INFECTION_STATE));
INFECTION_STATE, getCurrentInfectionStatus(FLAMEGPU));
FLAMEGPU->message_out.setVariable<demographics::AgeUnderlyingType>(v::AGE_DEMOGRAPHIC, FLAMEGPU->getVariable<demographics::AgeUnderlyingType>(v::AGE_DEMOGRAPHIC));
// Set the message key, the house hold idx for bucket messaging @Todo
FLAMEGPU->message_out.setKey(householdIdx);
Expand Down Expand Up @@ -65,43 +114,36 @@ FLAMEGPU_AGENT_FUNCTION(interactHousehold, flamegpu::MessageBucket, flamegpu::Me
p_s2e *= relativeSusceptibility;

// Check if the current individual is susceptible to being infected
auto infectionState = FLAMEGPU->getVariable<disease::SEIR::InfectionStateUnderlyingType>(v::INFECTION_STATE);
auto infectionState = getCurrentInfectionStatus(FLAMEGPU);

// @todo - this will need to change for contact tracing, the message interaction will need to occur regardless.
// Only check interactions from this individual if they are susceptible. @todo - this will need to change for contact tracing.

Check failure on line 119 in src/exateppabm/person.cu

View workflow job for this annotation

GitHub Actions / cpplint (12.0, ubuntu-22.04)

Line ends in whitespace. Consider deleting these extra spaces.
if (infectionState == disease::SEIR::Susceptible) {
// Variable to store the duration of the exposed phase (if exposed)
float stateDuration = 0.f;

// Bool to track if individual newly exposed - used to move expensive operations outside the message iteration loop.
bool newlyExposed = false;
// Iterate messages from anyone within the household
for (const auto &message : FLAMEGPU->message_in(householdIdx)) {
// Ignore self messages (can't infect oneself)
if (message.getVariable<flamegpu::id_t>(message::household_status::ID) != id) {
// Ignore messages from other households
// Ignore messages from other households (this should be superfluous.)
if (message.getVariable<std::uint32_t>(v::HOUSEHOLD_IDX) == householdIdx) {
// Check if the other agent is infected
if (message.getVariable<disease::SEIR::InfectionStateUnderlyingType>(v::INFECTION_STATE) == disease::SEIR::InfectionState::Infected) {
// Roll a dice
float r = FLAMEGPU->random.uniform<float>();
if (r < p_s2e) {
// I have been exposed
infectionState = disease::SEIR::InfectionState::Exposed;
// Generate how long until I am infected
float mean = FLAMEGPU->environment.getProperty<float>("mean_time_to_infected");
float sd = FLAMEGPU->environment.getProperty<float>("sd_time_to_infected");
stateDuration = (FLAMEGPU->random.normal<float>() * sd) + mean;
// @todo - for now only any exposure matters. This may want to change when quantity of exposure is important?
// Increment the infection counter for this individual
FLAMEGPU->setVariable<std::uint32_t>(v::INFECTION_COUNT, FLAMEGPU->getVariable<std::uint32_t>(v::INFECTION_COUNT) + 1);
// set a flag indicating that the individual has been exposed in this message iteration loop
newlyExposed = true;
// break out of the message iteration loop, currently no need to check for multiple exposures on the same day.
break;
}
}
}
}
}
// If newly exposed, store the value in global device memory.
if (infectionState == disease::SEIR::InfectionState::Exposed) {
FLAMEGPU->setVariable<disease::SEIR::InfectionStateUnderlyingType>(v::INFECTION_STATE, infectionState);
FLAMEGPU->setVariable<float>(person::v::INFECTION_STATE_DURATION, stateDuration);
// If newly exposed, update agent data and generate new seir state information. This is done outside the message iteration loop to be more GPU-shaped.
if (newlyExposed) {
// Transition from susceptible to exposed in SEIR
susceptibleToExposed(FLAMEGPU, infectionState);
}
}

Expand All @@ -123,7 +165,7 @@ FLAMEGPU_AGENT_FUNCTION(emitWorkplaceStatus, flamegpu::MessageNone, flamegpu::Me
FLAMEGPU->message_out.setVariable<std::uint32_t>(v::WORKPLACE_IDX, workplaceIdx);

FLAMEGPU->message_out.setVariable<disease::SEIR::InfectionStateUnderlyingType>(v::
INFECTION_STATE, FLAMEGPU->getVariable<disease::SEIR::InfectionStateUnderlyingType>(v::INFECTION_STATE));
INFECTION_STATE, getCurrentInfectionStatus(FLAMEGPU));
FLAMEGPU->message_out.setVariable<demographics::AgeUnderlyingType>(v::AGE_DEMOGRAPHIC, FLAMEGPU->getVariable<demographics::AgeUnderlyingType>(v::AGE_DEMOGRAPHIC));

// Set the message key, the house hold idx for bucket messaging @Todo
Expand Down Expand Up @@ -168,12 +210,12 @@ FLAMEGPU_AGENT_FUNCTION(interactWorkplace, flamegpu::MessageArray, flamegpu::Mes
p_s2e *= relativeSusceptibility;

// Check if the current individual is susceptible to being infected
auto infectionState = FLAMEGPU->getVariable<disease::SEIR::InfectionStateUnderlyingType>(v::INFECTION_STATE);
auto infectionState = getCurrentInfectionStatus(FLAMEGPU);

// @todo - this will need to change for contact tracing, the message interaction will need to occur regardless.
// Only check interactions from this individual if they are susceptible. @todo - this will need to change for contact tracing.
if (infectionState == disease::SEIR::Susceptible) {
// Variable to store the duration of the exposed phase (if exposed)
float stateDuration = 0.f;
// Bool to track if individual newly exposed - used to move expensive operations outside the message iteration loop.
bool newlyExposed = false;
// Iterate my downstream neighbours (the graph is undirected, so no need to iterate in and out
auto workplaceGraph = FLAMEGPU->environment.getDirectedGraph("WORKPLACE_DIGRAPH");
std::uint32_t myVertexIndex = workplaceGraph.getVertexIndex(id);
Expand All @@ -194,26 +236,20 @@ FLAMEGPU_AGENT_FUNCTION(interactWorkplace, flamegpu::MessageArray, flamegpu::Mes
// Roll a dice
float r = FLAMEGPU->random.uniform<float>();
if (r < p_s2e) {
// I have been exposed
infectionState = disease::SEIR::InfectionState::Exposed;
// Generate how long until I am infected
float mean = FLAMEGPU->environment.getProperty<float>("mean_time_to_infected");
float sd = FLAMEGPU->environment.getProperty<float>("sd_time_to_infected");
stateDuration = (FLAMEGPU->random.normal<float>() * sd) + mean;
// @todo - for now only any exposure matters. This may want to change when quantity of exposure is important?
// Increment the infection counter for this individual
FLAMEGPU->setVariable<std::uint32_t>(v::INFECTION_COUNT, FLAMEGPU->getVariable<std::uint32_t>(v::INFECTION_COUNT) + 1);
// set a flag indicating that the individual has been exposed in this message iteration loop
newlyExposed = true;
// break out of the message iteration loop, currently no need to check for multiple exposures on the same day.
break;
}
}
}
}
}
}
// If newly exposed, store the value in global device memory.
if (infectionState == disease::SEIR::InfectionState::Exposed) {
FLAMEGPU->setVariable<disease::SEIR::InfectionStateUnderlyingType>(v::INFECTION_STATE, infectionState);
FLAMEGPU->setVariable<float>(person::v::INFECTION_STATE_DURATION, stateDuration);
// If newly exposed, update agent data and generate new seir state information. This is done outside the message iteration loop to be more GPU-shaped.
if (newlyExposed) {
// Transition from susceptible to exposed in SEIR
susceptibleToExposed(FLAMEGPU, infectionState);
}
}

Expand Down Expand Up @@ -348,7 +384,7 @@ FLAMEGPU_AGENT_FUNCTION(emitRandomDailyNetworkStatus, flamegpu::MessageNone, fla
// FLAMEGPU->message_out.setVariable<flamegpu::id_t>(person::message::random_daily_status::ID, FLAMEGPU->getVariable<flamegpu::id_t>(person::v::ID));

FLAMEGPU->message_out.setVariable<disease::SEIR::InfectionStateUnderlyingType>(v::
INFECTION_STATE, FLAMEGPU->getVariable<disease::SEIR::InfectionStateUnderlyingType>(v::INFECTION_STATE));
INFECTION_STATE, getCurrentInfectionStatus(FLAMEGPU));
FLAMEGPU->message_out.setVariable<demographics::AgeUnderlyingType>(v::AGE_DEMOGRAPHIC, FLAMEGPU->getVariable<demographics::AgeUnderlyingType>(v::AGE_DEMOGRAPHIC));

// Set the message array message index to the agent's id.
Expand Down Expand Up @@ -382,13 +418,12 @@ FLAMEGPU_AGENT_FUNCTION(interactRandomDailyNetwork, flamegpu::MessageArray, flam
p_s2e *= relativeSusceptibility;

// Check if the current individual is susceptible to being infected
auto infectionState = FLAMEGPU->getVariable<disease::SEIR::InfectionStateUnderlyingType>(v::INFECTION_STATE);
auto infectionState = getCurrentInfectionStatus(FLAMEGPU);

// @todo - this will need to change for contact tracing, the message interaction will need to occur regardless.
// Only check interactions from this individual if they are susceptible. @todo - this will need to change for contact tracing.

Check failure on line 423 in src/exateppabm/person.cu

View workflow job for this annotation

GitHub Actions / cpplint (12.0, ubuntu-22.04)

Line ends in whitespace. Consider deleting these extra spaces.
if (infectionState == disease::SEIR::Susceptible) {
// Variable to store the duration of the exposed phase (if exposed)
float stateDuration = 0.f;

// Bool to track if individual newly exposed - used to move expensive operations outside the message iteration loop.
bool newlyExposed = false;
// For each interaction this agent is set to perform
const std::uint32_t randomInteractionCount = FLAMEGPU->getVariable<std::uint32_t>(person::v::RANDOM_INTERACTION_COUNT);
for (std::uint32_t randomInteractionIdx = 0; randomInteractionIdx < randomInteractionCount; ++randomInteractionIdx) {
Expand All @@ -403,24 +438,18 @@ FLAMEGPU_AGENT_FUNCTION(interactRandomDailyNetwork, flamegpu::MessageArray, flam
// Roll a dice
float r = FLAMEGPU->random.uniform<float>();
if (r < p_s2e) {
// I have been exposed
infectionState = disease::SEIR::InfectionState::Exposed;
// Generate how long until I am infected
float mean = FLAMEGPU->environment.getProperty<float>("mean_time_to_infected");
float sd = FLAMEGPU->environment.getProperty<float>("sd_time_to_infected");
stateDuration = (FLAMEGPU->random.normal<float>() * sd) + mean;
// @todo - for now only any exposure matters. This may want to change when quantity of exposure is important?
// Increment the infection counter for this individual
FLAMEGPU->setVariable<std::uint32_t>(v::INFECTION_COUNT, FLAMEGPU->getVariable<std::uint32_t>(v::INFECTION_COUNT) + 1);
// set a flag indicating that the individual has been exposed in this message iteration loop
newlyExposed = true;
// break out of the loop over today's random interactions - can only be exposed once
break;
}
}
}
}
// If newly exposed, store the value in global device memory.
if (infectionState == disease::SEIR::InfectionState::Exposed) {
FLAMEGPU->setVariable<disease::SEIR::InfectionStateUnderlyingType>(v::INFECTION_STATE, infectionState);
FLAMEGPU->setVariable<float>(person::v::INFECTION_STATE_DURATION, stateDuration);
// If newly exposed, update agent data and generate new seir state information. This is done outside the message iteration loop to be more GPU-shaped.
if (newlyExposed) {
// Transition from susceptible to exposed in SEIR
susceptibleToExposed(FLAMEGPU, infectionState);
}
}

Expand Down

0 comments on commit 5cc396b

Please sign in to comment.