Skip to content

Commit

Permalink
update download
Browse files Browse the repository at this point in the history
  • Loading branch information
zhuwq0 committed Nov 23, 2023
1 parent 4ef4932 commit dbc0274
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 14 deletions.
39 changes: 39 additions & 0 deletions slurm/create_filelist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# %%
import os
from glob import glob

# %%
protocol = "file"
token = None

## get from command line
root_path = "local"
region = "demo"
if len(os.sys.argv) > 1:
root_path = os.sys.argv[1]
region = os.sys.argv[2]
print(f"root_path: {root_path}")
print(f"region: {region}")

# %%
result_path = f"{region}/phasenet_das"
if not os.path.exists(f"{root_path}/{result_path}"):
os.makedirs(f"{root_path}/{result_path}", exist_ok=True)

# %%
folder_depth = 2
csv_list = sorted(glob(f"{root_path}/{result_path}/picks_phasenet_das/????-??-??/*.csv"))
csv_list = ["/".join(x.split("/")[-folder_depth:]) for x in csv_list]

# %%
hdf5_list = sorted(glob(f"{root_path}/{region}/????-??-??/*.h5"))
num_to_process = 0
with open(f"{root_path}/{result_path}/filelist.csv", "w") as fp:
# fp.write("\n".join(hdf5_list))
for line in hdf5_list:
csv_name = "/".join(line.split("/")[-folder_depth:]).replace(".h5", ".csv")
if csv_name not in csv_list:
fp.write(f"{line}\n")
num_to_process += 1

print(f"filelist.csv created in {root_path}/{result_path}: {num_to_process} / {len(hdf5_list)} to process")
10 changes: 7 additions & 3 deletions slurm/download_waveform_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,13 @@ def download(

except Exception as err:
err = str(err).rstrip("\n")
message = "No data available for request"
if err[: len(message)] == message:
print(f"{message} from {client.base_url}: {starttime.isoformat()} - {endtime.isoformat()}")
message1 = "No data available for request"
message2 = "The current client does not have a dataselect service"
if err[: len(message1)] == message1:
print(f"{message1} from {client.base_url}: {starttime.isoformat()} - {endtime.isoformat()}")
break
elif err[: len(message2)] == message2:
print(f"{message2} from {client.base_url}: {starttime.isoformat()} - {endtime.isoformat()}")
break
else:
print(f"Error occurred from {client.base_url}:{err}. Retrying...")
Expand Down
25 changes: 14 additions & 11 deletions slurm/run_phasenet_das.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
name: quakeflow

workdir: ./
workdir: .

num_nodes: 16
num_nodes: 1

resources:

Expand All @@ -12,14 +12,13 @@ resources:

zone: us-west1-b

# instance_type:
# instance_type: n2-highmem-16

accelerators: V100:1
# accelerators: V100:1

cpus: 4+
cpus: 8+

use_spot: True
# spot_recovery: none

# image_id: docker:zhuwq0/quakeflow:latest

Expand All @@ -32,8 +31,8 @@ envs:
file_mounts:

/data:
# source: s3://scedc-pds
# source: gs://quakeflow_dataset
# source: s3://scedc-pds/
# source: gs://quakeflow_dataset/
# source: gs://quakeflow_share/
source: gs://das_arcata/
mode: MOUNT
Expand All @@ -45,7 +44,7 @@ file_mounts:
~/.ssh/id_rsa.pub: ~/.ssh/id_rsa.pub
~/.ssh/id_rsa: ~/.ssh/id_rsa
~/.config/rclone/rclone.conf: ~/.config/rclone/rclone.conf
~/EQNet: ../EQNet
# EQNet: ../EQNet

setup: |
echo "Begin setup."
Expand All @@ -55,22 +54,26 @@ setup: |
pip3 install cartopy
pip3 install h5py tqdm wandb
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# pip3 install torch torchvision torchaudio
# mkdir ~/data && rclone mount range:/ ~/data --daemon
run: |
[ -d "EQNet" ] && rm -r "EQNet"
git clone https://github.com/AI4EPS/EQNet.git
num_nodes=`echo "$SKYPILOT_NODE_IPS" | wc -l`
master_addr=`echo "$SKYPILOT_NODE_IPS" | head -n1`
[[ ${SKYPILOT_NUM_GPUS_PER_NODE} -gt $NCPU ]] && nproc_per_node=${SKYPILOT_NUM_GPUS_PER_NODE} || nproc_per_node=$NCPU
if [ "${SKYPILOT_NODE_RANK}" == "0" ]; then
ls -al /data
python create_filelist.py
python create_filelist.py ${ROOT_PATH} ""
fi
torchrun \
--nproc_per_node=${nproc_per_node} \
--node_rank=${SKYPILOT_NODE_RANK} \
--nnodes=$num_nodes \
--master_addr=$master_addr \
--master_port=8008 \
../EQNet/predict.py --model phasenet_das --format=h5 --data_list=${ROOT_PATH}/${RESULT_PATH}/filelist.csv --result_path=${ROOT_PATH}/${RESULT_PATH} --batch_size 1 --workers 8 --system optasense
EQNet/predict.py --model phasenet_das --format=h5 --data_list=${ROOT_PATH}/${RESULT_PATH}/filelist.csv --result_path=${ROOT_PATH}/${RESULT_PATH} --batch_size 1 --workers 6 --folder_depth=2 --system optasense

0 comments on commit dbc0274

Please sign in to comment.