Skip to content

Commit

Permalink
Merge pull request #2562 from data-for-change/update-news-flash
Browse files Browse the repository at this point in the history
add option to update_cbs_location_only
  • Loading branch information
atalyaalon authored Feb 14, 2024
2 parents be09df1 + 89a8d1a commit 52f692c
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 20 deletions.
27 changes: 17 additions & 10 deletions anyway/parsers/location_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,18 +506,25 @@ def extract_location_text(text):
return text


def extract_geo_features(db, newsflash: NewsFlash) -> None:
newsflash.location = extract_location_text(newsflash.description) or extract_location_text(
newsflash.title
)
geo_location = geocode_extract(newsflash.location)
if geo_location is not None:
newsflash.lat = geo_location["geom"]["lat"]
newsflash.lon = geo_location["geom"]["lng"]
newsflash.resolution = set_accident_resolution(geo_location)
def extract_geo_features(db, newsflash: NewsFlash, update_cbs_location_only: bool) -> None:
location_from_db = None
if update_cbs_location_only:
location_from_db = get_db_matching_location(
db, newsflash.lat, newsflash.lon, newsflash.resolution, geo_location["road_no"]
db, newsflash.lat, newsflash.lon, newsflash.resolution, newsflash.road1
)
else:
newsflash.location = extract_location_text(newsflash.description) or extract_location_text(
newsflash.title
)
geo_location = geocode_extract(newsflash.location)
if geo_location is not None:
newsflash.lat = geo_location["geom"]["lat"]
newsflash.lon = geo_location["geom"]["lng"]
newsflash.resolution = set_accident_resolution(geo_location)
location_from_db = get_db_matching_location(
db, newsflash.lat, newsflash.lon, newsflash.resolution, geo_location["road_no"]
)
if location_from_db is not None:
for k, v in location_from_db.items():
setattr(newsflash, k, v)
all_resolutions = []
Expand Down
17 changes: 10 additions & 7 deletions anyway/parsers/news_flash.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
news_flash_classifiers = {"ynet": classify_rss, "twitter": classify_tweets, "walla": classify_rss}


def update_all_in_db(source=None, newsflash_id=None):
def update_all_in_db(source=None, newsflash_id=None, update_cbs_location_only=False):
"""
main function for newsflash updating.
Expand All @@ -29,11 +29,14 @@ def update_all_in_db(source=None, newsflash_id=None):
newsflash_items = db.get_all_newsflash()

for newsflash in newsflash_items:
classify = news_flash_classifiers[newsflash.source]
newsflash.organization = classify_organization(newsflash.source)
newsflash.accident = classify(newsflash.description or newsflash.title)
if not update_cbs_location_only:
classify = news_flash_classifiers[newsflash.source]
newsflash.organization = classify_organization(newsflash.source)
newsflash.accident = classify(newsflash.description or newsflash.title)
if newsflash.accident:
extract_geo_features(db, newsflash)
extract_geo_features(
db=db, newsflash=newsflash, update_cbs_location_only=update_cbs_location_only
)
db.commit()


Expand All @@ -47,7 +50,7 @@ def scrape_extract_store_rss(site_name, db):
newsflash.organization = classify_organization(site_name)
if newsflash.accident:
# FIX: No accident-accurate date extracted
extract_geo_features(db, newsflash)
extract_geo_features(db=db, newsflash=newsflash, update_cbs_location_only=False)
newsflash.set_critical()
db.insert_new_newsflash(newsflash)

Expand All @@ -61,7 +64,7 @@ def scrape_extract_store_twitter(screen_name, db):
newsflash.accident = classify_tweets(newsflash.description)
newsflash.organization = classify_organization("twitter")
if newsflash.accident:
extract_geo_features(db, newsflash)
extract_geo_features(db=db, newsflash=newsflash, update_cbs_location_only=False)
newsflash.set_critical()
db.insert_new_newsflash(newsflash)

Expand Down
3 changes: 2 additions & 1 deletion anyway/parsers/news_flash_db_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import pandas as pd
import numpy as np
from sqlalchemy import desc
from flask_sqlalchemy import SQLAlchemy
from anyway.parsers import infographics_data_cache_updater
from anyway.parsers import timezones
Expand Down Expand Up @@ -87,7 +88,7 @@ def select_newsflash_where_source(self, source):
return self.db.session.query(NewsFlash).filter(NewsFlash.source == source)

def get_all_newsflash(self):
return self.db.session.query(NewsFlash)
return self.db.session.query(NewsFlash).order_by(desc(NewsFlash.date))

def get_latest_date_of_source(self, source):
"""
Expand Down
5 changes: 3 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,15 @@ def update_news_flash():
@update_news_flash.command()
@click.option("--source", default="", type=str)
@click.option("--news_flash_id", default="", type=str)
@click.option("--update_cbs_location_only", is_flag=True)
def update(source, news_flash_id, update_cbs_location_only):
    """Re-run newsflash processing over stored items.

    ``--source`` / ``--news_flash_id`` restrict which newsflashes are
    processed; ``--update_cbs_location_only`` refreshes only the CBS
    location matching without re-running classification/geocoding.
    """
    from anyway.parsers import news_flash

    # click hands us "" when an option is omitted; normalize empty
    # strings to None so update_all_in_db treats them as "no filter".
    source = source or None
    news_flash_id = news_flash_id or None
    return news_flash.update_all_in_db(source, news_flash_id, update_cbs_location_only)


@update_news_flash.command()
Expand Down

0 comments on commit 52f692c

Please sign in to comment.