Skip to content

Commit

Permalink
no more LEGACY - ANN helpers
Browse files Browse the repository at this point in the history
  • Loading branch information
forefire committed Aug 26, 2024
1 parent a0fb8cd commit f425339
Show file tree
Hide file tree
Showing 8 changed files with 62 additions and 66 deletions.
43 changes: 0 additions & 43 deletions src/ANNPropagationModel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,42 +75,13 @@ int ANNPropagationModel::isInitialized =
/* constructor */
ANNPropagationModel::ANNPropagationModel(const int & mindex, DataBroker* db)
: PropagationModel(mindex, db) {
windReductionFactor = params->getDouble("windReductionFactor");

/* slope = registerProperty("slope");
normalWind = registerProperty("normalWind");
Rhod = registerProperty("fuel.Rhod");
sd = registerProperty("fuel.sd");
Ta = registerProperty("fuel.Ta");
std::cout << slope<<" "<< normalWind<<" "<< Rhod<<" "<< sd<<" "<< Ta<<endl;
if (numProperties > 0) properties = new double[numProperties];
std::string annPath = params->getParameter("FFANNPropagationModelPath");
loadNetwork(annPath); // Load the ANN model from file
dataBroker->registerPropagationModel(this);*/


windReductionFactor = params->getDouble("windReductionFactor");

// Load network first to access names
std::string annPath = params->getParameter("FFANNPropagationModelPath");

annNetwork.loadFromFile(annPath.c_str());



// Dynamically register properties based on network input names
properties = new double[annNetwork.inputNames.size()];
for (size_t i = 0; i < annNetwork.inputNames.size(); ++i) {
registerProperty(annNetwork.inputNames[i]);
std::cout << "Registered property: " << annNetwork.inputNames[i] << std::endl;
}

dataBroker->registerPropagationModel(this);



}


Expand All @@ -129,21 +100,7 @@ string ANNPropagationModel::getName(){

double ANNPropagationModel::getSpeed(double* valueOf) {
std::vector<float> inputs(numProperties);
/*std::cout << "firing Neurons in [";
for (size_t i = 0; i < numProperties; ++i) {
inputs[i] = static_cast<float>(valueOf[i]);
std::cout << inputs[i];
if (i < numProperties - 1) std::cout << ", ";
}*/

// Process the inputs through the network
std::vector<float> outputs = annNetwork.processInput(inputs);

// std::cout << "] out: " << outputs[0] << std::endl; // Only use endl here to flush the output

// Return the first output converted to double
return static_cast<double>(abs(outputs[0]));
}

Expand Down
5 changes: 2 additions & 3 deletions src/BMapLoggerForANNTraining.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,13 @@ BMapLoggerForANNTraining::BMapLoggerForANNTraining(const int & mindex, DataBroke


annNetwork.loadFromFile(annPath.c_str());
csvfile.open(csvPath);
csvfile.open(csvPath);
properties = new double[annNetwork.inputNames.size() + 1];
csvfile << "ROS";
registerProperty("arrival_time_gradient");
for (const auto& inputName : annNetwork.inputNames) {
csvfile << ";"<< inputName ;
std::cout << "Registered property: " << inputName << std::endl;

registerProperty(inputName);
}
csvfile << std::endl;
Expand All @@ -102,7 +102,6 @@ string BMapLoggerForANNTraining::getName(){


double BMapLoggerForANNTraining::getSpeed(double* valueOf) {

double RosVal = 0.0;
if (valueOf[0] > 0){
RosVal = 1.0/valueOf[0];
Expand Down
2 changes: 0 additions & 2 deletions src/Command.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1039,12 +1039,10 @@ int Command::loadData(const string& arg, size_t& numTabs){
if (args.size() == 2){

simParam->setParameter("NetCDFfile", args[0]);
cout<<" Loading data "<<args[0]<<endl;
try
{
NcFile dataFile(path.c_str(), NcFile::read);
if (!dataFile.isNull()) {
cout<<" NC0 data "<<endl;
NcVar domVar = dataFile.getVar("domain");
if (!domVar.isNull()) {
map<string,NcVarAtt> attributeList = domVar.getAtts();
Expand Down
4 changes: 2 additions & 2 deletions src/DataBroker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,8 @@ void DataBroker::registerLayer(string name, DataLayer<double>* layer) {
// if added
//ForcedROSLayer = ;
forcedArrivalTimeLayer = new TimeGradientDataLayer<double>("arrival_time_gradient", layer,
params->getDouble("spatialIncrement"));
cout<<"forced ROS"<<endl;
params->getDouble("LookAheadDistanceForeTimeGradientDataLayer"));
registerLayer("arrival_time_gradient", forcedArrivalTimeLayer);
}
if (name.find("moisture") != string::npos){
Expand Down
51 changes: 42 additions & 9 deletions src/FireDomain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,7 @@ void FireDomain::readMultiDomainMetadata(){
}
}


FluxModel* FireDomain::fluxModelInstanciation(const int& index, string modelname){
FluxModelMap::iterator fmodel;

Expand All @@ -723,6 +724,18 @@ void FireDomain::readMultiDomainMetadata(){
}

void FireDomain::registerPropagationModel(const int& index, PropagationModel* model){
// Join all strings in the vector into a single string separated by semicolons
std::string properties = "";
for(const std::string& prop : model->wantedProperties) {
if(properties.empty()) {
properties = prop; // First element, add without semicolon
} else {
properties += ";" + prop; // Subsequent elements, add with semicolon
}
}

// Set the parameter with the joined string
params->setParameter(model->getName() + ".keys", properties);
propModelsTable[index] = model;
}

Expand Down Expand Up @@ -1369,6 +1382,10 @@ void FireDomain::readMultiDomainMetadata(){

return propModelsTable[modelIndex]->getSpeedForNode(fn) * propagationSpeedAdjustmentFactor;
}

vector<string> FireDomain::getFirstPropagationModelKeys() {
return propModelsTable[0]->wantedProperties;
}

// Computing the flux at a given location according to a given flux model
double FireDomain::getModelValueAt(int& modelIndex
Expand Down Expand Up @@ -4695,19 +4712,27 @@ void FireDomain::loadWindDataInBinary(double refTime){
params->setInt("refDay", pday);

double max_time = params->getDouble("InitTime");
double allDataAtime[FSPACE_DIM2][FSPACE_DIM1];
// Dynamically allocate the matrix on the heap

double* allDataAtime = new double[globalBMapSizeY * globalBMapSizeX];

atime.getVar(allDataAtime);

double tmpval = 0;
for (size_t i = 0; i < globalBMapSizeX; i++) {
for (size_t j = 0; j < globalBMapSizeY; j++) {
tmpval = allDataAtime[j][i];
tmpval = allDataAtime[j * globalBMapSizeX + i];
if ((tmpval < max_time) && (tmpval > -9999)) {
this->setArrivalTime(i , j,tmpval );
}
}
}
dataFile.close();


delete[] allDataAtime;


}
catch (std::exception const & e)
{
Expand Down Expand Up @@ -4744,23 +4769,26 @@ void FireDomain::loadWindDataInBinary(double refTime){
vector<NcDim> dims = {yDim, xDim}; // Order is important for visualization
NcVar atime = dataFile.addVar("arrival_time_of_front", ncDouble, dims);
atime.setCompression(true, true, 6);



double* matrix = new double[globalBMapSizeY * globalBMapSizeX];



double matrix[globalBMapSizeY][globalBMapSizeX];

for (size_t i = 0; i < globalBMapSizeX; i++) {
for (size_t j = 0; j < globalBMapSizeY; j++) {

double tmpval = this->getArrivalTime(i , j );
if (tmpval == numeric_limits<double>::infinity()) {
matrix[j][i] = -9999 ;
if (std::isinf(tmpval)) {
matrix[j * globalBMapSizeX + i] = -9999 ;
}else{
matrix[j][i] = tmpval;
matrix[j * globalBMapSizeX + i] = tmpval;
}
}
}

// Write the arrival time data
atime.putVar(&matrix[0][0]);
atime.putVar(&matrix[0]);
// Add domain and reference attributes
NcDim domdim = dataFile.addDim("domdim", 1);
NcVar dom = dataFile.addVar("domain", ncChar, {domdim});
Expand All @@ -4774,6 +4802,11 @@ void FireDomain::loadWindDataInBinary(double refTime){
// Close the file
dataFile.close();

/* for (size_t i = 0; i < globalBMapSizeY; i++) {
delete[] matrix[i];
}*/
delete[] matrix;

}
catch (std::exception const & e)
{
Expand Down
5 changes: 4 additions & 1 deletion src/FireDomain.h
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,8 @@ class FireDomain: public ForeFireAtom, Visitable {
static int registerPropagationModelInstantiator(string, PropagationModelInstantiator);
void updateFuelTable( string , double );
PropagationModel* propModelInstanciation(const int&, string);
vector<string> getPropagationModelKeys(string modelname);

void registerPropagationModel(const int&, PropagationModel*);
bool addPropagativeLayer(string);
size_t getFreePropModelIndex();
Expand Down Expand Up @@ -564,7 +566,8 @@ class FireDomain: public ForeFireAtom, Visitable {

/*! \brief Computing the propagation speed of a given firenode */
double getPropagationSpeed(FireNode*);

vector<string> getFirstPropagationModelKeys() ;

/*! \brief Computing the propagation speed of a given firenode */
double getModelValueAt(int&, FFPoint&, const double&, const double&, const double&);

Expand Down
2 changes: 2 additions & 0 deletions src/SimulationParameters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ SimulationParameters::SimulationParameters(){
parameters.insert(make_pair("InitFile", "Init.ff"));
parameters.insert(make_pair("InitFiles", "output"));
parameters.insert(make_pair("InitTime", "99999999999999"));

parameters.insert(make_pair("LookAheadDistanceForeTimeGradientDataLayer", "40"));
parameters.insert(make_pair("BMapsFiles", "1234567890"));
parameters.insert(make_pair("SHIFT_ALL_POINT_ABSCISSA_BY", "0"));
parameters.insert(make_pair("SHIFT_ALL_POINT_ORDINATES_BY", "0"));
Expand Down
16 changes: 10 additions & 6 deletions src/TimeGradientDataLayer.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,16 @@ template<typename T> class TimeGradientDataLayer: public DataLayer<T> {

template<typename T>
T TimeGradientDataLayer<T>::getValueAt(FireNode* fn){
/* Computing the gradient between the next and present location */
T currentValue = fn->getTime();
T nextValue;
FFPoint nextLoc = fn->getLoc() + dx*(fn->getNormal().toPoint());
nextValue = parent->getValueAt(nextLoc,fn->getUpdateTime());
return (nextValue - currentValue)/dx;
/* Computing the gradient between the next and present location */
T currentValue = fn->getTime();
T nextValue;
FFPoint nextLoc = fn->getLoc() + dx*(fn->getNormal().toPoint());
nextValue = parent->getValueAt(nextLoc,fn->getUpdateTime());

// Debug print statement
//std::cout << "currentValue: " << currentValue << ", nextValue: " << nextValue << ", nextLoc: (" << nextLoc.x << ", " << nextLoc.y << "), dx: " << dx << std::endl;

return (nextValue - currentValue)/dx;
}

template<typename T>
Expand Down

0 comments on commit f425339

Please sign in to comment.