From 0711ce89763124f953b655224c99f466a2f54894 Mon Sep 17 00:00:00 2001 From: xuetao chen <cxt@mail.ustc.edu.cn> Date: Fri, 19 Jan 2024 11:24:13 +0800 Subject: [PATCH] initial --- .idea/.gitignore | 8 + .idea/FourCastNetTEC.iml | 8 + .../inspectionProfiles/profiles_settings.xml | 6 + .idea/misc.xml | 4 + .idea/modules.xml | 8 + .idea/vcs.xml | 6 + config/AFNO.bak.yaml | 131 ++++ config/AFNO.yaml | 195 ++++++ networks/__pycache__/afnonet.cpython-38.pyc | Bin 0 -> 9351 bytes networks/afnonet.py | 283 ++++++++ train.py | 616 ++++++++++++++++++ 11 files changed, 1265 insertions(+) create mode 100644 .idea/.gitignore create mode 100644 .idea/FourCastNetTEC.iml create mode 100644 .idea/inspectionProfiles/profiles_settings.xml create mode 100644 .idea/misc.xml create mode 100644 .idea/modules.xml create mode 100644 .idea/vcs.xml create mode 100644 config/AFNO.bak.yaml create mode 100644 config/AFNO.yaml create mode 100644 networks/__pycache__/afnonet.cpython-38.pyc create mode 100644 networks/afnonet.py create mode 100644 train.py diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/FourCastNetTEC.iml b/.idea/FourCastNetTEC.iml new file mode 100644 index 0000000..bc6c1cf --- /dev/null +++ b/.idea/FourCastNetTEC.iml @@ -0,0 +1,8 @@ +<?xml version="1.0" encoding="UTF-8"?> +<module type="PYTHON_MODULE" version="4"> + <component name="NewModuleRootManager"> + <content url="file://$MODULE_DIR$" /> + <orderEntry type="jdk" jdkName="castnet (3)" jdkType="Python SDK" /> + <orderEntry type="sourceFolder" forTests="false" /> + </component> +</module> \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ +<component name="InspectionProjectProfileManager"> + <settings> + <option name="USE_PROJECT_PROFILE" value="false" /> + <version value="1.0" /> + </settings> +</component> \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..466135d --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ +<?xml version="1.0" encoding="UTF-8"?> +<project version="4"> + <component name="ProjectRootManager" version="2" project-jdk-name="castnet (3)" project-jdk-type="Python SDK" /> +</project> \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..dffda97 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ +<?xml version="1.0" encoding="UTF-8"?> +<project version="4"> + <component name="ProjectModuleManager"> + <modules> + <module fileurl="file://$PROJECT_DIR$/.idea/FourCastNetTEC.iml" filepath="$PROJECT_DIR$/.idea/FourCastNetTEC.iml" /> + </modules> + </component> +</project> \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..35eb1dd --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ +<?xml version="1.0" encoding="UTF-8"?> +<project version="4"> + <component name="VcsDirectoryMappings"> + <mapping directory="" vcs="Git" /> + </component> +</project> \ No newline at end of file diff --git a/config/AFNO.bak.yaml b/config/AFNO.bak.yaml new file mode 100644 index 0000000..72c8fb9 --- /dev/null +++ b/config/AFNO.bak.yaml @@ -0,0 +1,131 @@ +### base config ### +full_field: &FULL_FIELD + loss: 'l2' + lr: 1E-3 + scheduler: 'ReduceLROnPlateau' + num_data_workers: 4 + dt: 1 # how many timesteps ahead the model will predict + n_history: 0 #how many previous timesteps to consider + prediction_type: 'iterative' + prediction_length: 41 #applicable only if prediction_type == 'iterative' + n_initial_conditions: 5 #applicable only if prediction_type == 'iterative' + ics_type: "default" + save_raw_forecasts: !!bool True + save_channel: !!bool False + masked_acc: !!bool False + maskpath: None + perturb: !!bool False + add_grid: !!bool False + N_grid_channels: 0 + gridtype: 'sinusoidal' #options 'sinusoidal' or 'linear' + roll: !!bool False + max_epochs: 50 + batch_size: 64 + + #afno hyperparams + num_blocks: 8 + nettype: 'afno' + patch_size: 8 + width: 56 + modes: 32 + #options default, residual + target: 'default' + in_channels: [0,1] + out_channels: [0,1] #must be same as in_channels if prediction_type == 'iterative' + normalization: 'zscore' #options zscore (minmax not supported) + train_data_path: '/pscratch/sd/j/jpathak/wind/train' + valid_data_path: '/pscratch/sd/j/jpathak/wind/test' + inf_data_path: '/pscratch/sd/j/jpathak/wind/out_of_sample' # test set path for inference + exp_dir: '/pscratch/sd/j/jpathak/ERA5_expts_gtc/wind' + time_means_path: '/pscratch/sd/j/jpathak/wind/time_means.npy' + global_means_path: '/pscratch/sd/j/jpathak/wind/global_means.npy' + global_stds_path: '/pscratch/sd/j/jpathak/wind/global_stds.npy' + + orography: !!bool False + orography_path: None + + log_to_screen: !!bool True + log_to_wandb: !!bool True + save_checkpoint: !!bool True + + enable_nhwc: !!bool False + optimizer_type: 'FusedAdam' + crop_size_x: None + crop_size_y: None + + two_step_training: !!bool False + plot_animations: !!bool False + + add_noise: !!bool False + noise_std: 0 + +afno_backbone: &backbone + <<: *FULL_FIELD + log_to_wandb: !!bool True + lr: 5E-4 + batch_size: 2 + max_epochs: 150 + scheduler: 'CosineAnnealingLR' + in_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] + out_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] + orography: !!bool False + orography_path: None + exp_dir: '/mnt/data/work/FourCastNet/results/nowcasting' + train_data_path: '/mnt/data/work/FourCastNet/data/train' + valid_data_path: '/mnt/data/work/FourCastNet/data/test' + inf_data_path: '/mnt/data/work/FourCastNet/data/out_of_sample' + time_means_path: '/mnt/data/work/FourCastNet/data/time_means.npy' + global_means_path: '/mnt/data/work/FourCastNet/data/global_means.npy' + global_stds_path: '/mnt/data/work/FourCastNet/data/global_stds.npy' + +afno_backbone_orography: &backbone_orography + <<: *backbone + orography: !!bool True + orography_path: '/pscratch/sd/s/shas1693/data/era5/static/orography.h5' + +afno_backbone_finetune: + <<: *backbone + lr: 1E-4 + batch_size: 64 + log_to_wandb: !!bool True + max_epochs: 50 + pretrained: !!bool True + two_step_training: !!bool True + pretrained_ckpt_path: '/pscratch/sd/s/shas1693/results/era5_wind/afno_backbone/0/training_checkpoints/best_ckpt.tar' + +perturbations: + <<: *backbone + lr: 1E-4 + batch_size: 64 + max_epochs: 50 + pretrained: !!bool True + two_step_training: !!bool True + pretrained_ckpt_path: '/pscratch/sd/j/jpathak/ERA5_expts_gtc/wind/afno_20ch_bs_64_lr5em4_blk_8_patch_8_cosine_sched/1/training_checkpoints/best_ckpt.tar' + prediction_length: 24 + ics_type: "datetime" + n_perturbations: 100 + save_channel: !bool True + save_idx: 4 + save_raw_forecasts: !!bool False + date_strings: ["2018-01-01 00:00:00"] + inference_file_tag: " " + valid_data_path: "/pscratch/sd/j/jpathak/ " + perturb: !!bool True + n_level: 0.3 + +### PRECIP ### +precip: &precip + <<: *backbone + in_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] + out_channels: [0] + nettype: 'afno' + nettype_wind: 'afno' + log_to_wandb: !!bool True + lr: 2.5E-4 + batch_size: 64 + max_epochs: 25 + precip: '/pscratch/sd/p/pharring/ERA5/precip/total_precipitation' + time_means_path_tp: '/pscratch/sd/p/pharring/ERA5/precip/total_precipitation/time_means.npy' + model_wind_path: '/pscratch/sd/s/shas1693/results/era5_wind/afno_backbone_finetune/0/training_checkpoints/best_ckpt.tar' + precip_eps: !!float 1e-5 + diff --git a/config/AFNO.yaml b/config/AFNO.yaml new file mode 100644 index 0000000..765146c --- /dev/null +++ b/config/AFNO.yaml @@ -0,0 +1,195 @@ +### base config ### +full_field: &FULL_FIELD + loss: 'l2' + lr: 1E-3 + scheduler: 'ReduceLROnPlateau' + num_data_workers: 4 + dt: 1 # how many timesteps ahead the model will predict + n_history: 0 #how many previous timesteps to consider + prediction_type: 'iterative' + prediction_length: 41 #applicable only if prediction_type == 'iterative' + n_initial_conditions: 5 #applicable only if prediction_type == 'iterative' + ics_type: "default" + save_raw_forecasts: !!bool True + save_channel: !!bool False + masked_acc: !!bool False + maskpath: None + perturb: !!bool False + add_grid: !!bool False + N_grid_channels: 0 + gridtype: 'sinusoidal' #options 'sinusoidal' or 'linear' + roll: !!bool False + max_epochs: 50 + batch_size: 64 + + #afno hyperparams + num_blocks: 8 + nettype: 'afno' + patch_size: 8 + width: 56 + modes: 32 + #options default, residual + target: 'default' + in_channels: [0,1] + out_channels: [0,1] #must be same as in_channels if prediction_type == 'iterative' + normalization: 'zscore' #options zscore (minmax not supported) + train_data_path: '/pscratch/sd/j/jpathak/wind/train' + valid_data_path: '/pscratch/sd/j/jpathak/wind/test' + inf_data_path: '/pscratch/sd/j/jpathak/wind/out_of_sample' # test set path for inference + exp_dir: '/pscratch/sd/j/jpathak/ERA5_expts_gtc/wind' + time_means_path: '/pscratch/sd/j/jpathak/wind/time_means.npy' + global_means_path: '/pscratch/sd/j/jpathak/wind/global_means.npy' + global_stds_path: '/pscratch/sd/j/jpathak/wind/global_stds.npy' + + orography: !!bool False + orography_path: None + + log_to_screen: !!bool True + log_to_wandb: !!bool True + save_checkpoint: !!bool True + + enable_nhwc: !!bool False + optimizer_type: 'FusedAdam' + crop_size_x: None + crop_size_y: None + + two_step_training: !!bool False + plot_animations: !!bool False + + add_noise: !!bool False + noise_std: 0 + + +afno_backbone_cxtwin: + <<: *FULL_FIELD + log_to_wandb: !!bool True + lr: 5E-4 + batch_size: 16 + max_epochs: 1510 + scheduler: 'CosineAnnealingLR' + in_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29] + out_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29] +# in_channels: [0, 1, 2, 3, 4, 5] +# out_channels: [0, 1, 2, 3, 4, 5] + orography: !!bool False + orography_path: None + exp_dir: '/home/cxt/work/fourcastnet/FourCastNetwithobswin/results/ljkj' + train_data_path: '/home/cxt/work/fourcastnet/FourCastNetwithobswin/data_ljkj/train' + valid_data_path: '/home/cxt/work/fourcastnet/FourCastNetwithobswin/data_ljkj/test' + inf_data_path: '/home/cxt/work/fourcastnet/FourCastNetwithobswin/data_ljkj/out_of_sample' + time_means_path: '/home/cxt/work/fourcastnet/FourCastNetwithobswin/data_ljkj/time_means.npy' + global_means_path: '/home/cxt/work/fourcastnet/FourCastNetwithobswin/data_ljkj/global_means.npy' + global_stds_path: '/home/cxt/work/fourcastnet/FourCastNetwithobswin/data_ljkj/global_stds.npy' + +afno_backbone_cxt: &backbone_cxt + <<: *FULL_FIELD + log_to_wandb: !!bool True + lr: 5E-4 + batch_size: 2 + max_epochs: 150 + scheduler: 'CosineAnnealingLR' +# in_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] +# out_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] + in_channels: [0, 1, 2, 3, 4, 5] + out_channels: [0, 1, 2, 3, 4, 5] + orography: !!bool False + orography_path: None + exp_dir: '/mnt/data/work/FourCastNet_linux/results/era5_wind' + train_data_path: '/mnt/data/work/FourCastNet_linux/data_nowcasting/train' + valid_data_path: '/mnt/data/work/FourCastNet_linuxdata_nowcasting/test' + inf_data_path: '/mnt/data/work/FourCastNet_linux/data_nowcasting/out_of_sample' + time_means_path: '/mnt/data/work/FourCastNet_linux/data_nowcasting/time_means.npy' + global_means_path: '/mnt/data/work/FourCastNet_linux/data_nowcasting/global_means.npy' + global_stds_path: '/mnt/data/work/FourCastNet_linux/data_nowcasting/global_stds.npy' + +afno_backbone_nowcasting: + <<: *FULL_FIELD + log_to_wandb: !!bool True + lr: 5E-4 + batch_size: 2 + max_epochs: 150 + scheduler: 'CosineAnnealingLR' +# in_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] +# out_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] + in_channels: [0] + out_channels: [0] + orography: !!bool False + orography_path: None + exp_dir: '/mnt/data/work/FourCastNet/results/nowcasting' + train_data_path: '/mnt/data/work/FourCastNet/data_nowcasting/train' + valid_data_path: '/mnt/data/work/FourCastNet/data_nowcasting/test' + inf_data_path: '/mnt/data/work/FourCastNet/data_nowcasting/out_of_sample' + time_means_path: '/mnt/data/work/FourCastNet/data_nowcasting/time_means.npy' + global_means_path: '/mnt/data/work/FourCastNet/data_nowcasting/global_means.npy' + global_stds_path: '/mnt/data/work/FourCastNet/data_nowcasting/global_stds.npy' + +afno_backbone: &backbone + <<: *FULL_FIELD + log_to_wandb: !!bool True + lr: 5E-4 + batch_size: 64 + max_epochs: 150 + scheduler: 'CosineAnnealingLR' + in_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] + out_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] + orography: !!bool False + orography_path: None + exp_dir: '/pscratch/sd/s/shas1693/results/era5_wind' + train_data_path: '/pscratch/sd/s/shas1693/data/era5/train' + valid_data_path: '/pscratch/sd/s/shas1693/data/era5/test' + inf_data_path: '/pscratch/sd/s/shas1693/data/era5/out_of_sample' + time_means_path: '/pscratch/sd/s/shas1693/data/era5/time_means.npy' + global_means_path: '/pscratch/sd/s/shas1693/data/era5/global_means.npy' + global_stds_path: '/pscratch/sd/s/shas1693/data/era5/global_stds.npy' + +afno_backbone_orography: &backbone_orography + <<: *backbone + orography: !!bool True + orography_path: '/pscratch/sd/s/shas1693/data/era5/static/orography.h5' + +afno_backbone_finetune: + <<: *backbone + lr: 1E-4 + batch_size: 64 + log_to_wandb: !!bool True + max_epochs: 50 + pretrained: !!bool True + two_step_training: !!bool True + pretrained_ckpt_path: '/pscratch/sd/s/shas1693/results/era5_wind/afno_backbone/0/training_checkpoints/best_ckpt.tar' + +perturbations: + <<: *backbone + lr: 1E-4 + batch_size: 64 + max_epochs: 50 + pretrained: !!bool True + two_step_training: !!bool True + pretrained_ckpt_path: '/pscratch/sd/j/jpathak/ERA5_expts_gtc/wind/afno_20ch_bs_64_lr5em4_blk_8_patch_8_cosine_sched/1/training_checkpoints/best_ckpt.tar' + prediction_length: 24 + ics_type: "datetime" + n_perturbations: 100 + save_channel: !bool True + save_idx: 4 + save_raw_forecasts: !!bool False + date_strings: ["2018-01-01 00:00:00"] + inference_file_tag: " " + valid_data_path: "/pscratch/sd/j/jpathak/ " + perturb: !!bool True + n_level: 0.3 + +### PRECIP ### +precip: &precip + <<: *backbone + in_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] + out_channels: [0] + nettype: 'afno' + nettype_wind: 'afno' + log_to_wandb: !!bool True + lr: 2.5E-4 + batch_size: 64 + max_epochs: 25 + precip: '/pscratch/sd/p/pharring/ERA5/precip/total_precipitation' + time_means_path_tp: '/pscratch/sd/p/pharring/ERA5/precip/total_precipitation/time_means.npy' + model_wind_path: '/pscratch/sd/s/shas1693/results/era5_wind/afno_backbone_finetune/0/training_checkpoints/best_ckpt.tar' + precip_eps: !!float 1e-5 + diff --git a/networks/__pycache__/afnonet.cpython-38.pyc b/networks/__pycache__/afnonet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..52e1f283c268f643e47d3571e726bd962162820d GIT binary patch literal 9351 zcmb7KYiuM}R<2vGu70@P<MDeQ=}fYj>B+>NG0RJY$zvu7fvhtmSy(l@)VBLpySrWe zm|NvI?rjUSp2&_C!=l}Vg;h{nLcj=E_`&j<U;INt{Hz}c#E-~ANbm!#6p0PrIaU4e z%;Z(Ad+NUH*16}L`<+w$<*BKRhU<;hKl8d*HSIsBFnQT1JdYF{3Qd!m6uwrbyY7pY zUe*QW4c};)Ws~!!Z?){Q-Aa{HT(*3tl`f|_Z~K{6ww&d>@N?xUo-bd{^L@HiC>N5p zGp&NI39U6#KFc*Jf3`JOo@>pQ=aU`_Jj(GGTTA67u1ovNt#jpboX_~@TNla~S{KU~ zxt#SMYZc2y&gcBcTbIh0g!YOir)2(~Ci7mQ`NY0aeiAj)vVfXGQu77WoRKrAnMrD% zLd{t@i<;S_<}zyL<UDHTlbS22S&)mUSxjn5-c`A@rGHzfd)^hfEYHdF@`Aj0WR{=y z@`G#gF_eq)@nf_6jQ6zn%)ThEc&pw@^ZLFJ8glu;4R7UGP)Uk6wbG>%+N1T7h*I5( z3L6z4+3f2|ddicx8`ZF+N7hS9byVa?&+Aq@-Cgt$Z-N-zibS(>TXnj3Dq%gE3e|SI z>b5(oRq<W)$h_&@-S*nav~$(ESKaD%8tu>x#tnLuR$fJ^N_*3bY~^*mN{F%PH-?oc zOx^KRqaz#DI~BPmtF-3H%S;Nvb!0rPETmS}rI5linx@oi=7Cld(var8blH-YwC`zU z`#>|batb_1MaGx>?)nGA8LK3!!;A7Xy@3?0gM`|?HW0@Us=m=J^u?ah*Sm|j>!Cg{ z`l7EN>wV2N`{ureyT!G7-{4vsrBvTK7HDy#_NI1Ute4V}6>N7saM*Pl?MCRjk=|}c z_G^tcSP&Vt>W#>#R71+GMJZx;XFH5ciA>3k%)s+&QMS=`YvA0r@`7l(-jLFpsK{Z! zxH*l<T)(pGDRmBSbRW!i-Kt*+0@uB-eRzAd-f4NO)px?xosQaCt#!6lwGxDFFI;^Y z`CG`Zd*Mzatamnookn{VEz~nut<>5b6jr*sY8Epu64UY^n#c$TzdWa24*XsjzHC*e z4#`D%V)ZSg;AcSk+Cb}zW3dO84a6v?k8)sn(mJezT&o!jEtzToH$|)97gbBh-PhH5 z+#>OgdV(65JV`aFT1V|vlzbe8AO}LL9{ze+T65B)<j`KVD=iOum2urxM{fI+&$;g1 z?TR0_s5zR|eC4IrzRcWEm#I9Ju)6B`zUxk`fgY=9A*ujED-THU9I=?@nkovXE^4JK zcnV}B^$yFrDJ{slA#It0tV8}G>lS1^9oa9uy#D&y?Gp!sKKRD$4`2EB-o58eV3W0n z<em(=ffRfQG9mYn@p+h!Ej_pz>Jn`K0Zeya|AF?txi7wnO@eeAeS@`tbiwLtO_Z9} zGHHa}H$kBVppK+`I?N2RXitZ^!IY%ll<Xuqv)ZwVM6>2UwkP#6l3vrF)Qe^v%~yDI z&(Tb0NJmTdx+4Dca*0|6jh<KRT`C6k&bBX$8(vX1-fjdNzE|AXEw;B??uOr~ZUx1J zzu$fCgygz3qt4>SvJmB82m%iVx6_WTYj$GpXeNL$4jSRE8`h!E^^Py2%y`}?!}Ghe zj%c|KksbGIv^U+Ff`Vh_Xl7zkw}wiXaVw}+d@q*G^qq==5Qnfm7TT)y$buP`?MUCb z5$PL9J8QVFm2B1-^+j4fX{GvUB0opu=ZUcFjx>qi@M+Ws%OIphMco$k%QLn0tKunf zRrID~S$wo9AiI%8x`h;cA96@GwlA6*EUq9643i+5`d7_8bI;nd_fjE@JtPNaU+U60 zGN^q((q{E-X)bHhT7djDQ%KISwI|R&{peQo$joY+n5&tU_L08le2%_p?u(U;MrR(X z$@Z<`iuQ70uAjpbr)IT%u{Raw2h*W3D99Av1fb|$9j#M~8qo!mq}HE$SFagMTJsD> z%}8enlK&Y;oaGS~jlhgrbCyOgkZ$2L`uJx(t9_1D#n1RWtBT))SH(LtE0O0_r9Wj= zF`qth?Mw8YbEmlUiL)mC$2^*kdDPE*rnP*QxhHw^S=K)0%zVt5e)iMW(p>0g`q@og zrey~5E5dve7zLbIl36YrgXQGld{{?P$0los**u3a(4HEBDP87{P!F9t4>UsuGR}bo zvX4{iC0!BpQfnZQ)lonOK!%2Fv?7z>AUALqfH_-wZ#Sx5uduSRveCG<(dayLld>JQ zvL_$mDhWkukIt4~X>G`o5s84S)NQK25Lq(Z?Rt?_^E;IgMy*}}f_j!7Z`5k2P@n)y zWId^cjn>cd>Uy5)1hLtAiTeYOvMGzi%aN%(e><|hMmyNX+5t>`b(4CTjaFq-J%_A% zp2!PSX$4`Wx)o)DPAv@Ts?pwJ!x&A!-SBqYO5j#It*-CA6WNV;Jti%>l#i_`9ovz) z(WnIKRmzGNBk@a-`1MHKDyKR(TzWWX=+Vfm4Qtk1uE~a-5I7s)!wbrpEw3BK`3X$H zaDre5aV98(Xu^ao1`^X}#6|s@xF%-AjFCgx6l-GbBVBiXqJMPBK2=Ix7uS=P;$+^Q z5d~QHMf@^)FP~WRAxJXS3ZAD{iLi7Kxo@j0R8b-VKQTTI85*f;ROC~H2_Gu_2G5{@ z`X+!Ve@0PF73tKKPD%m2?FbN&<m%)IAX(%DkONRlGfYZJ2QbMdEdfkQF-&qG(b>p) zk&f`hCn!9x&S1*7#xImK{P%NJI--Ubfktu|45Xk7G69<GKD;hK7(FB>!Re>_nSBC3 za1;QgO!dXkg3BP<c4!Y$#{`s|{%{Nc5-A%W5ZN$`JPwe)A@w7gA=k%|=E4WY!i{xa zDWLBKJYHQSQY1n^AA&aqAQtK64Q6Cy*BU-RfFev7BLnAE_zxivNR%cw!|jrP$3xcC z9qLs8iHw%tjk2<{4d~+rTa9ih7mM~SYE7fnRTUtS7=3TjxFWS==-G(}(xlqcn4+Ag zkqLHz4S5v2P#6}_BD5)j=moK?XG9Ld+{^RMJsMNc#!#a51#f{crg)@_(bUz|*us&e z1M<+wyGvXRbCd%{7~4YXIdnd3B{9G9i10hsDc2xEj3^lsM!<oYYDhQ^6Vg9Y-S^Nt zm<NIC;!QKGz9!z5evar(#qT@Ry7YtkMP7&^G!~^_BcE@b>@j&~pQ3aXbh(Sf&vr~* z(9gs#&guV~Ur^MiigfBC2GbSLD-6|;)eWk6&w*D$p7^9!VnyjY%BwcIa4W0B;WP#1 zNhs+a5CbO^ij@f%3%k-ZYUF27jue7$8@vbj7*y+oX&I|tCd@VoX%PZU!76Y_Kw3xC zcPh*d@~n2#L(G32&pV+f^SSg^$&9jFo@#r347}|iR1N7-;6zHmZuJ}JS}IJam<<$P zX$4UR+EcB^qo&s3<Ga;*rHzS#$Vm!8lwWtrNguVOhb45{z$vo<$68I}5s}%&6Js^C zZ*|&lugS=);vSiAcyKkz-5p9-NUQG=VPpCP<yf$2|NjUnAV--HI$?A2&~4<;i{8v= zKOU_}r~GuX+n?D_A1R0#FYFr_toDeopZ*2P@jAzgy^Oj)MG7(?W2d}#X1vHm`qkIb zVPg3B2<f2V`$)+VlFKTpsUn@a*zHG6@G{6ycI6yPZ&v2mr{-8++4QWY4Kd<4AL49I z9N#Oa<EUOa!-q*0HhDfu(a8wezOSGB1LkcYeuy~437IdoPlRyO2}$`0EkJD$p%W)= zU8Ni?waT#+WayekB8FfTY#C;p)Qv;i)%!YVtP{|S8kJzTnilRjSehx^okQK!a8D!Q zh-qeV&xW~hs+lLmcci~(h0`=o(?qfl$*)4rEud6rp4rQXQ-hiC>|hpteEv{B1};SF z9LCQd3TPIU7UCcTdM!!^s5EtK+=i#~p4Qh1S<~EfU#4{|q0UU|aOb*ZJmcJfb|4P) z1LMFvunz2lRL#Meox10Q=liLBLG$sGFC1Ej`aV7NBG&NOA+6jQK4CU~7UrGRWCl3c zK)+)6INIlAPEH*m2z_V|E}{IyF$GBJK+Ugj9{tyk20#ACzr6`GYIxlMcED-0Hd&q3 zHkhw=K%%tQ+VG@H#}ye4bQawzc4c&t%9u!RA_b&>n^2ms<rq<G$J+1fcTJofPoM!~ z%|__80HO7j>xyvVgiMqEMww^*Mi3Iby!lTkvsBxfuJxWwL~W$DE6)PRRQ%v3GQ%Nk z&>mVGi5}sD;3^0&?h@Pp4NNo-#UYMLYz%~_DGpPIhSZLnc*!MEa$;z{g>e%{y-&r{ zCp-Lv%CHvsZy_Iv#(g>yx+;MhKn$Kbq+T3f;7NHZ3fskj2Vzoe$9@O7e?j7Soo(np zDF_p6&_8+^E(955bpzx+b}k5I>B+HiR%{v%s0Z)i+mxe?RPPYk1&MMLav26y<0bLy z&`z@bRc-2)LGZg;e~Gp~%EqRiWg*Jac&_jQ^$RqJfT?<&2Bo{5z~yd=f+UecWrA9f z<HPyrR;pek>XF!qGGF58;%nGE4$n}OoZ~+M`$CGKL1flFoCySgBdgNw`n!lTax8hr zYi!oTpfnxBjR)SEUzpyN-wsX{oQS$&mj51V@o8iT{s_n=ZqzvT^F12*IFa8aGCEoQ z3Z;O+zjlUBmf_iwW4~iwIQl&NjYWM{%<AV2`~NalM!%xxMX&In@IDk9CIz}hvlIRc zA(J5jE&IB9kx(@Z2^%dj+0rJHC$+7#YNIc(HJ1^N7h3yTGJ}H|9%>!+Hl<0u*&JBw zEKus(csy$n+hv4RGmQY-7FOETIJU;y5^snd?*WS=6dMu#D2?-)I)^Tc5LAJLI})u2 z^jr;4{jW$#6dit=m<Pzs0r2+bM%<d*=~HS=_J^dESw=vL7*y9|N8t}pE~l%0;D((o zubr5lSWZ~OMrsyKw>xh97OwQFmE8|8G@yw!y|*x2#shs$?H%nKc}BdDUqc4FfW<6n z46Fg06A-O65mJCj9D{8!FbT$uLc$!JK-sxX#*@O#0C#)h?rnnJB|Q$+meT51@DTM~ zBJ`f>4~eW3A=sn-m<Su0_mP9U#Kwg!?N?DzHa7gN2lgT=Bp_pC68;lCV;_XEIAQa7 z%IUquAx|D0a*8R$KvG#|%2$wi6o?8DN0bz~fxE}CDZppG<HJ@`#+@ST5dXM^dm!Fv z9@()+QnJ>2*EWjv;!g2O_eQb1R;&`(EES)*S=@;8#Vhq<_eQCBWe3frShV>ZGb6o= zKyY_$%(#-J{sOC2e?lZjgiR<L(~rY%KSsC2iy^-y_8|yxBK+{7?ZlyAK%~(pmYn)+ zn)#bV2p+}c9_8L8vO$C~(j$Np{Lm}ZULr!rS?ur(C`UFpGMdmryRivJ)Kh;>)%-@6 zDR+&?SOM6+v1$7)H2ftJJ2sAmPk@j5{ZD?)Ndrb3PEOH?D$=Qobb+qZ-)-W61bv&N z?=}>enDme`>N^+#%=mA(sUH$y&U7h9OsQHR!;3VHl7Yl-?j$rMtpm&gh0wqpaEV1| zR@;N?1ZN1)5!B#MG-e2#3oRopWcY`c6x%{`Jr3)Tk%0xfi<=^|3^+hAB4#-Y1&aqC z7IW0l2Z0x(I?dsSg>f%FkytUgKQo{Af_H#mla1$%KF<~Tc(FzA?D!BKL7!8~M*JYS zd<!X{V~{KmJ|5_M27N%Fk5Wx2+LE?Kjz4{XV&32**H`e73-6DQT=<j#8n>Y**ChRr zx8G0m=dZ+e4NRFj76S{8X6hIrShi12Y@e8My_r|r-R-bQ!91_XAC;~|;%aZ9RFoYr zXkQMCEkYJWitzcxD?Pom!dd{Xv2|mv{S;nB{S64dU;33W^x9E6#8;dEfWjN{cVyU< zOWnUl<Tr^hYsdU0`Og!79b_~dHouI}9-P4uz!|Bhss5VApB6&(b*dZB$-BoJN78H2 z#3cNPJ~}MV^cA&LfeY-G^pk(q;r<LD#>i7I_4h<bz)xu3V51+962>iLqtxFK`4M$m zi%jsW9;IveCqmfi_<_Q-T9mDJeBa}*Vu3QK1}-zcQuBjIbs`;xfMldR1Ju^tUGW<m zEA+2~U2?Wl)WhNie4FpaHpJmiW-IMBW5M#-VR40@jqh(2Kgx{r>Q{O8u+eI*@Qwy6 z{0%;c!)|mCl+*DXh+UVLN7JD*zAgtVzziXng*YAYBbtvpY=;dWZKJ<ds4O*71iRA> zV(UCbQ=g4%l4q{)V3nu7)A1DK({L`1GT1Tf<Uk)s31j5m!3Z{Myn$?GBa2^0{USd& zs8Gn8{>Tw*`#`K+gvG;HUT@Rbv2={YE^?lY-|;5xOhB;T0mLDYLzGSlME3Lo-;-Kn U=IqQXGnYZjGX;wOQENv0KV@JWN&o-= literal 0 HcmV?d00001 diff --git a/networks/afnonet.py b/networks/afnonet.py new file mode 100644 index 0000000..8d05ec0 --- /dev/null +++ b/networks/afnonet.py @@ -0,0 +1,283 @@ +#reference: https://github.com/NVlabs/AFNO-transformer + +import math +from functools import partial +from collections import OrderedDict +from copy import Error, deepcopy +from re import S +from numpy.lib.arraypad import pad +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +#from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.models.layers import DropPath, trunc_normal_ +import torch.fft +from torch.nn.modules.container import Sequential +from torch.utils.checkpoint import checkpoint_sequential +from einops import rearrange, repeat +from einops.layers.torch import Rearrange +from utils.img_utils import PeriodicPad2d + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class AFNO2D(nn.Module): + def __init__(self, hidden_size, num_blocks=8, sparsity_threshold=0.01, hard_thresholding_fraction=1, hidden_size_factor=1): + super().__init__() + assert hidden_size % num_blocks == 0, f"hidden_size {hidden_size} should be divisble by num_blocks {num_blocks}" + + self.hidden_size = hidden_size + self.sparsity_threshold = sparsity_threshold + self.num_blocks = num_blocks + self.block_size = self.hidden_size // self.num_blocks + self.hard_thresholding_fraction = hard_thresholding_fraction + self.hidden_size_factor = hidden_size_factor + self.scale = 0.02 + + self.w1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size, self.block_size * self.hidden_size_factor)) + self.b1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor)) + self.w2 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor, self.block_size)) + self.b2 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size)) + + def forward(self, x): + bias = x + + dtype = x.dtype + x = x.float() + B, H, W, C = x.shape + + x = torch.fft.rfft2(x, dim=(1, 2), norm="ortho") + x = x.reshape(B, H, W // 2 + 1, self.num_blocks, self.block_size) + + o1_real = torch.zeros([B, H, W // 2 + 1, self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device) + o1_imag = torch.zeros([B, H, W // 2 + 1, self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device) + o2_real = torch.zeros(x.shape, device=x.device) + o2_imag = torch.zeros(x.shape, device=x.device) + + + total_modes = H // 2 + 1 + kept_modes = int(total_modes * self.hard_thresholding_fraction) + + o1_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = F.relu( + torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].real, self.w1[0]) - \ + torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].imag, self.w1[1]) + \ + self.b1[0] + ) + + o1_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = F.relu( + torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].imag, self.w1[0]) + \ + torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].real, self.w1[1]) + \ + self.b1[1] + ) + + o2_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = ( + torch.einsum('...bi,bio->...bo', o1_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[0]) - \ + torch.einsum('...bi,bio->...bo', o1_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[1]) + \ + self.b2[0] + ) + + o2_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = ( + torch.einsum('...bi,bio->...bo', o1_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[0]) + \ + torch.einsum('...bi,bio->...bo', o1_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[1]) + \ + self.b2[1] + ) + + x = torch.stack([o2_real, o2_imag], dim=-1) + x = F.softshrink(x, lambd=self.sparsity_threshold) + x = torch.view_as_complex(x) + x = x.reshape(B, H, W // 2 + 1, C) + x = torch.fft.irfft2(x, s=(H, W), dim=(1,2), norm="ortho") + x = x.type(dtype) + + return x + bias + + +class Block(nn.Module): + def __init__( + self, + dim, + mlp_ratio=4., + drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + double_skip=True, + num_blocks=8, + sparsity_threshold=0.01, + hard_thresholding_fraction=1.0 + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.filter = AFNO2D(dim, num_blocks, sparsity_threshold, hard_thresholding_fraction) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + #self.drop_path = nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.double_skip = double_skip + + def forward(self, x): + residual = x + x = self.norm1(x) + x = self.filter(x) + + if self.double_skip: + x = x + residual + residual = x + + x = self.norm2(x) + x = self.mlp(x) + x = self.drop_path(x) + x = x + residual + return x + +class PrecipNet(nn.Module): + def __init__(self, params, backbone): + super().__init__() + self.params = params + self.patch_size = (params.patch_size, params.patch_size) + self.in_chans = params.N_in_channels + self.out_chans = params.N_out_channels + self.backbone = backbone + self.ppad = PeriodicPad2d(1) + self.conv = nn.Conv2d(self.out_chans, self.out_chans, kernel_size=3, stride=1, padding=0, bias=True) + self.act = nn.ReLU() + + def forward(self, x): + x = self.backbone(x) + x = self.ppad(x) + x = self.conv(x) + x = self.act(x) + return x + +class AFNONet(nn.Module): + def __init__( + self, + params, + # img_size=(720, 1440), modified by cxt 2023.12.31 + img_size=(192, 288), + patch_size=(16, 16), + in_chans=2, + out_chans=2, + embed_dim=768, + depth=12, + mlp_ratio=4., + drop_rate=0., + drop_path_rate=0., + num_blocks=16, + sparsity_threshold=0.01, + hard_thresholding_fraction=1.0, + ): + super().__init__() + self.params = params + self.img_size = img_size + self.patch_size = (params.patch_size, params.patch_size) + self.in_chans = params.N_in_channels + self.out_chans = params.N_out_channels + self.num_features = self.embed_dim = embed_dim + self.num_blocks = params.num_blocks + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.patch_embed = PatchEmbed(img_size=img_size, patch_size=self.patch_size, in_chans=self.in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] + + self.h = img_size[0] // self.patch_size[0] + self.w = img_size[1] // self.patch_size[1] + + self.blocks = nn.ModuleList([ + Block(dim=embed_dim, mlp_ratio=mlp_ratio, drop=drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + num_blocks=self.num_blocks, sparsity_threshold=sparsity_threshold, hard_thresholding_fraction=hard_thresholding_fraction) + for i in range(depth)]) + + self.norm = norm_layer(embed_dim) + + self.head = nn.Linear(embed_dim, self.out_chans*self.patch_size[0]*self.patch_size[1], bias=False) + + trunc_normal_(self.pos_embed, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def forward_features(self, x): + B = x.shape[0] + x = self.patch_embed(x) + x = x + self.pos_embed + x = self.pos_drop(x) + + x = x.reshape(B, self.h, self.w, self.embed_dim) + for blk in self.blocks: + x = blk(x) + + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + x = rearrange( + x, + "b h w (p1 p2 c_out) -> b c_out (h p1) (w p2)", + p1=self.patch_size[0], + p2=self.patch_size[1], + h=self.img_size[0] // self.patch_size[0], + w=self.img_size[1] // self.patch_size[1], + ) + return x + + +class PatchEmbed(nn.Module): + def __init__(self, img_size=(224, 224), patch_size=(16, 16), in_chans=3, embed_dim=768): + super().__init__() + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + B, C, H, W = x.shape + assert H == self.img_size[0] and W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +if __name__ == "__main__": + model = AFNONet(img_size=(720, 1440), patch_size=(4,4), in_chans=3, out_chans=10) + sample = torch.randn(1, 3, 720, 1440) + result = model(sample) + print(result.shape) + print(torch.norm(result)) + diff --git a/train.py b/train.py new file mode 100644 index 0000000..89e33e7 --- /dev/null +++ b/train.py @@ -0,0 +1,616 @@ +#BSD 3-Clause License +# +#Copyright (c) 2022, FourCastNet authors +#All rights reserved. +# +#Redistribution and use in source and binary forms, with or without +#modification, are permitted provided that the following conditions are met: +# +#1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +#2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +#3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +#AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +#IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +#DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +#FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +#DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +#SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +#CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +#OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +#OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +#The code was authored by the following people: +# +#Jaideep Pathak - NVIDIA Corporation +#Shashank Subramanian - NERSC, Lawrence Berkeley National Laboratory +#Peter Harrington - NERSC, Lawrence Berkeley National Laboratory +#Sanjeev Raja - NERSC, Lawrence Berkeley National Laboratory +#Ashesh Chattopadhyay - Rice University +#Morteza Mardani - NVIDIA Corporation +#Thorsten Kurth - NVIDIA Corporation +#David Hall - NVIDIA Corporation +#Zongyi Li - California Institute of Technology, NVIDIA Corporation +#Kamyar Azizzadenesheli - Purdue University +#Pedram Hassanzadeh - Rice University +#Karthik Kashinath - NVIDIA Corporation +#Animashree Anandkumar - California Institute of Technology, NVIDIA Corporation + +import os +import time +import numpy as np +import argparse +import h5py +import torch +import cProfile +import re +import torchvision +from torchvision.utils import save_image +import torch.nn as nn +import torch.cuda.amp as amp +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel +import logging +from utils import logging_utils +logging_utils.config_logger() +from utils.YParams import YParams +from utils.data_loader_multifiles import get_data_loader +from networks.afnonet import AFNONet, PrecipNet +from utils.img_utils import vis_precip +import wandb +from utils.weighted_acc_rmse import weighted_acc, weighted_rmse, weighted_rmse_torch, unlog_tp_torch +from apex import optimizers +from utils.darcy_loss import LpLoss +import matplotlib.pyplot as plt +from collections import OrderedDict +import pickle +DECORRELATION_TIME = 36 # 9 days +import json +from ruamel.yaml import YAML +from ruamel.yaml.comments import CommentedMap as ruamelDict + +class Trainer(): + def count_parameters(self): + return sum(p.numel() for p in self.model.parameters() if p.requires_grad) + + def __init__(self, params, world_rank): + + self.params = params + self.world_rank = world_rank + self.device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu' + + if params.log_to_wandb: + wandb.init(config=params, name=params.name, group=params.group, project=params.project) + + logging.info('rank %d, begin data loader init'%world_rank) + self.train_data_loader, self.train_dataset, self.train_sampler = get_data_loader(params, params.train_data_path, dist.is_initialized(), train=True) + self.valid_data_loader, self.valid_dataset = get_data_loader(params, params.valid_data_path, dist.is_initialized(), train=False) + self.loss_obj = LpLoss() + logging.info('rank %d, data loader initialized'%world_rank) + + params.crop_size_x = self.valid_dataset.crop_size_x + params.crop_size_y = self.valid_dataset.crop_size_y + params.img_shape_x = self.valid_dataset.img_shape_x + params.img_shape_y = self.valid_dataset.img_shape_y + + # precip models + self.precip = True if "precip" in params else False + + if self.precip: + if 'model_wind_path' not in params: + raise Exception("no backbone model weights specified") + # load a wind model + # the wind model has out channels = in channels + out_channels = np.array(params['in_channels']) + params['N_out_channels'] = len(out_channels) + + if params.nettype_wind == 'afno': + self.model_wind = AFNONet(params).to(self.device) + else: + raise Exception("not implemented") + + if dist.is_initialized(): + self.model_wind = DistributedDataParallel(self.model_wind, + device_ids=[params.local_rank], + output_device=[params.local_rank],find_unused_parameters=True) + self.load_model_wind(params.model_wind_path) + self.switch_off_grad(self.model_wind) # no backprop through the wind model + + + # reset out_channels for precip models + if self.precip: + params['N_out_channels'] = len(params['out_channels']) + + if params.nettype == 'afno': + self.model = AFNONet(params).to(self.device) + else: + raise Exception("not implemented") + + # precip model + if self.precip: + self.model = PrecipNet(params, backbone=self.model).to(self.device) + + if self.params.enable_nhwc: + # NHWC: Convert model to channels_last memory format + self.model = self.model.to(memory_format=torch.channels_last) + + if params.log_to_wandb: + wandb.watch(self.model) + + if params.optimizer_type == 'FusedAdam': + self.optimizer = optimizers.FusedAdam(self.model.parameters(), lr = params.lr) + else: + self.optimizer = torch.optim.Adam(self.model.parameters(), lr = params.lr) + + if params.enable_amp == True: + self.gscaler = amp.GradScaler() + + if dist.is_initialized(): + self.model = DistributedDataParallel(self.model, + device_ids=[params.local_rank], + output_device=[params.local_rank],find_unused_parameters=True) + + self.iters = 0 + self.startEpoch = 0 + if params.resuming: + logging.info("Loading checkpoint %s"%params.checkpoint_path) + self.restore_checkpoint(params.checkpoint_path) + if params.two_step_training: + if params.resuming == False and params.pretrained == True: + logging.info("Starting from pretrained one-step afno model at %s"%params.pretrained_ckpt_path) + self.restore_checkpoint(params.pretrained_ckpt_path) + self.iters = 0 + self.startEpoch = 0 + #logging.info("Pretrained checkpoint was trained for %d epochs"%self.startEpoch) + #logging.info("Adding %d epochs specified in config file for refining pretrained model"%self.params.max_epochs) + #self.params.max_epochs += self.startEpoch + + + self.epoch = self.startEpoch + + if params.scheduler == 'ReduceLROnPlateau': + self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, factor=0.2, patience=5, mode='min') + elif params.scheduler == 'CosineAnnealingLR': + self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=params.max_epochs, last_epoch=self.startEpoch-1) + else: + self.scheduler = None + + '''if params.log_to_screen: + logging.info(self.model)''' + if params.log_to_screen: + logging.info("Number of trainable model parameters: {}".format(self.count_parameters())) + + def switch_off_grad(self, model): + for param in model.parameters(): + param.requires_grad = False + + def train(self): + if self.params.log_to_screen: + logging.info("Starting Training Loop...") + + best_valid_loss = 1.e6 + for epoch in range(self.startEpoch, self.params.max_epochs): + if dist.is_initialized(): + self.train_sampler.set_epoch(epoch) +# self.valid_sampler.set_epoch(epoch) + + start = time.time() + tr_time, data_time, train_logs = self.train_one_epoch() + valid_time, valid_logs = self.validate_one_epoch() + if epoch==self.params.max_epochs-1 and self.params.prediction_type == 'direct': + valid_weighted_rmse = self.validate_final() + + + + if self.params.scheduler == 'ReduceLROnPlateau': + self.scheduler.step(valid_logs['valid_loss']) + elif self.params.scheduler == 'CosineAnnealingLR': + self.scheduler.step() + if self.epoch >= self.params.max_epochs: + logging.info("Terminating training after reaching params.max_epochs while LR scheduler is set to CosineAnnealingLR") + exit() + + if self.params.log_to_wandb: + for pg in self.optimizer.param_groups: + lr = pg['lr'] + wandb.log({'lr': lr}) + + if self.world_rank == 0: + if self.params.save_checkpoint: + #checkpoint at the end of every epoch + self.save_checkpoint(self.params.checkpoint_path) + if valid_logs['valid_loss'] <= best_valid_loss: + #logging.info('Val loss improved from {} to {}'.format(best_valid_loss, valid_logs['valid_loss'])) + self.save_checkpoint(self.params.best_checkpoint_path) + best_valid_loss = valid_logs['valid_loss'] + + if self.params.log_to_screen: + logging.info('Time taken for epoch {} is {} sec'.format(epoch + 1, time.time()-start)) + #logging.info('train data time={}, train step time={}, valid step time={}'.format(data_time, tr_time, valid_time)) + logging.info('Train loss: {}. Valid loss: {}'.format(train_logs['loss'], valid_logs['valid_loss'])) +# if epoch==self.params.max_epochs-1 and self.params.prediction_type == 'direct': +# logging.info('Final Valid RMSE: Z500- {}. T850- {}, 2m_T- {}'.format(valid_weighted_rmse[0], valid_weighted_rmse[1], valid_weighted_rmse[2])) + + + + def train_one_epoch(self): + self.epoch += 1 + tr_time = 0 + data_time = 0 + self.model.train() + + for i, data in enumerate(self.train_data_loader, 0): + self.iters += 1 + # adjust_LR(optimizer, params, iters) + data_start = time.time() + inp, tar = map(lambda x: x.to(self.device, dtype = torch.float), data) + if self.params.orography and self.params.two_step_training: + orog = inp[:,-2:-1] + + + if self.params.enable_nhwc: + inp = inp.to(memory_format=torch.channels_last) + tar = tar.to(memory_format=torch.channels_last) + + + if 'residual_field' in self.params.target: + tar -= inp[:, 0:tar.size()[1]] + data_time += time.time() - data_start + + tr_start = time.time() + + self.model.zero_grad() + if self.params.two_step_training: + with amp.autocast(self.params.enable_amp): + gen_step_one = self.model(inp).to(self.device, dtype = torch.float) + loss_step_one = self.loss_obj(gen_step_one, tar[:,0:self.params.N_out_channels]) + if self.params.orography: + gen_step_two = self.model(torch.cat( (gen_step_one, orog), axis = 1) ).to(self.device, dtype = torch.float) + else: + gen_step_two = self.model(gen_step_one).to(self.device, dtype = torch.float) + loss_step_two = self.loss_obj(gen_step_two, tar[:,self.params.N_out_channels:2*self.params.N_out_channels]) + loss = loss_step_one + loss_step_two + else: + with amp.autocast(self.params.enable_amp): + if self.precip: # use a wind model to predict 17(+n) channels at t+dt + with torch.no_grad(): + inp = self.model_wind(inp).to(self.device, dtype = torch.float) + gen = self.model(inp.detach()).to(self.device, dtype = torch.float) + else: + gen = self.model(inp).to(self.device, dtype = torch.float) + loss = self.loss_obj(gen, tar) + + if self.params.enable_amp: + self.gscaler.scale(loss).backward() + self.gscaler.step(self.optimizer) + else: + loss.backward() + self.optimizer.step() + + if self.params.enable_amp: + self.gscaler.update() + + tr_time += time.time() - tr_start + + try: + logs = {'loss': loss, 'loss_step_one': loss_step_one, 'loss_step_two': loss_step_two} + except: + logs = {'loss': loss} + + if dist.is_initialized(): + for key in sorted(logs.keys()): + dist.all_reduce(logs[key].detach()) + logs[key] = float(logs[key]/dist.get_world_size()) + + if self.params.log_to_wandb: + wandb.log(logs, step=self.epoch) + + return tr_time, data_time, logs + + def validate_one_epoch(self): + self.model.eval() + n_valid_batches = 20 #do validation on first 20 images, just for LR scheduler + if self.params.normalization == 'minmax': + raise Exception("minmax normalization not supported") + elif self.params.normalization == 'zscore': + mult = torch.as_tensor(np.load(self.params.global_stds_path)[0, self.params.out_channels, 0, 0]).to(self.device) + + valid_buff = torch.zeros((3), dtype=torch.float32, device=self.device) + valid_loss = valid_buff[0].view(-1) + valid_l1 = valid_buff[1].view(-1) + valid_steps = valid_buff[2].view(-1) + valid_weighted_rmse = torch.zeros((self.params.N_out_channels), dtype=torch.float32, device=self.device) + valid_weighted_acc = torch.zeros((self.params.N_out_channels), dtype=torch.float32, device=self.device) + + valid_start = time.time() + + sample_idx = np.random.randint(len(self.valid_data_loader)) + with torch.no_grad(): + for i, data in enumerate(self.valid_data_loader, 0): + if (not self.precip) and i>=n_valid_batches: + break + inp, tar = map(lambda x: x.to(self.device, dtype = torch.float), data) + if self.params.orography and self.params.two_step_training: + orog = inp[:,-2:-1] + + if self.params.two_step_training: + gen_step_one = self.model(inp).to(self.device, dtype = torch.float) + loss_step_one = self.loss_obj(gen_step_one, tar[:,0:self.params.N_out_channels]) + + if self.params.orography: + gen_step_two = self.model(torch.cat( (gen_step_one, orog), axis = 1) ).to(self.device, dtype = torch.float) + else: + gen_step_two = self.model(gen_step_one).to(self.device, dtype = torch.float) + + loss_step_two = self.loss_obj(gen_step_two, tar[:,self.params.N_out_channels:2*self.params.N_out_channels]) + valid_loss += loss_step_one + loss_step_two + valid_l1 += nn.functional.l1_loss(gen_step_one, tar[:,0:self.params.N_out_channels]) + else: + if self.precip: + with torch.no_grad(): + inp = self.model_wind(inp).to(self.device, dtype = torch.float) + gen = self.model(inp.detach()) + else: + gen = self.model(inp).to(self.device, dtype = torch.float) + valid_loss += self.loss_obj(gen, tar) + valid_l1 += nn.functional.l1_loss(gen, tar) + + valid_steps += 1. + # save fields for vis before log norm + if (i == sample_idx) and (self.precip and self.params.log_to_wandb): + fields = [gen[0,0].detach().cpu().numpy(), tar[0,0].detach().cpu().numpy()] + + if self.precip: + gen = unlog_tp_torch(gen, self.params.precip_eps) + tar = unlog_tp_torch(tar, self.params.precip_eps) + + #direct prediction weighted rmse + if self.params.two_step_training: + if 'residual_field' in self.params.target: + valid_weighted_rmse += weighted_rmse_torch((gen_step_one + inp), (tar[:,0:self.params.N_out_channels] + inp)) + else: + valid_weighted_rmse += weighted_rmse_torch(gen_step_one, tar[:,0:self.params.N_out_channels]) + else: + if 'residual_field' in self.params.target: + valid_weighted_rmse += weighted_rmse_torch((gen + inp), (tar + inp)) + else: + valid_weighted_rmse += weighted_rmse_torch(gen, tar) + + + if not self.precip: + try: + os.mkdir(params['experiment_dir'] + "/" + str(i)) + except: + pass + #save first channel of image + if self.params.two_step_training: + save_image(torch.cat((gen_step_one[0,0], torch.zeros((self.valid_dataset.img_shape_x, 4)).to(self.device, dtype = torch.float), tar[0,0]), axis = 1), params['experiment_dir'] + "/" + str(i) + "/" + str(self.epoch) + ".png") + else: + save_image(torch.cat((gen[0,0], torch.zeros((self.valid_dataset.img_shape_x, 4)).to(self.device, dtype = torch.float), tar[0,0]), axis = 1), params['experiment_dir'] + "/" + str(i) + "/" + str(self.epoch) + ".png") + + + if dist.is_initialized(): + dist.all_reduce(valid_buff) + dist.all_reduce(valid_weighted_rmse) + + # divide by number of steps + valid_buff[0:2] = valid_buff[0:2] / valid_buff[2] + valid_weighted_rmse = valid_weighted_rmse / valid_buff[2] + if not self.precip: + valid_weighted_rmse *= mult + + # download buffers + valid_buff_cpu = valid_buff.detach().cpu().numpy() + valid_weighted_rmse_cpu = valid_weighted_rmse.detach().cpu().numpy() + + valid_time = time.time() - valid_start + valid_weighted_rmse = mult*torch.mean(valid_weighted_rmse, axis = 0) + if self.precip: + logs = {'valid_l1': valid_buff_cpu[1], 'valid_loss': valid_buff_cpu[0], 'valid_rmse_tp': valid_weighted_rmse_cpu[0]} + else: + try: + logs = {'valid_l1': valid_buff_cpu[1], 'valid_loss': valid_buff_cpu[0], 'valid_rmse_u10': valid_weighted_rmse_cpu[0], 'valid_rmse_v10': valid_weighted_rmse_cpu[1]} + except: + logs = {'valid_l1': valid_buff_cpu[1], 'valid_loss': valid_buff_cpu[0], 'valid_rmse_u10': valid_weighted_rmse_cpu[0]}#, 'valid_rmse_v10': valid_weighted_rmse[1]} + + if self.params.log_to_wandb: + if self.precip: + fig = vis_precip(fields) + logs['vis'] = wandb.Image(fig) + plt.close(fig) + wandb.log(logs, step=self.epoch) + + return valid_time, logs + + def validate_final(self): + self.model.eval() + n_valid_batches = int(self.valid_dataset.n_patches_total/self.valid_dataset.n_patches) #validate on whole dataset + valid_weighted_rmse = torch.zeros(n_valid_batches, self.params.N_out_channels) + if self.params.normalization == 'minmax': + raise Exception("minmax normalization not supported") + elif self.params.normalization == 'zscore': + mult = torch.as_tensor(np.load(self.params.global_stds_path)[0, self.params.out_channels, 0, 0]).to(self.device) + + with torch.no_grad(): + for i, data in enumerate(self.valid_data_loader): + if i>100: + break + inp, tar = map(lambda x: x.to(self.device, dtype = torch.float), data) + if self.params.orography and self.params.two_step_training: + orog = inp[:,-2:-1] + if 'residual_field' in self.params.target: + tar -= inp[:, 0:tar.size()[1]] + + if self.params.two_step_training: + gen_step_one = self.model(inp).to(self.device, dtype = torch.float) + loss_step_one = self.loss_obj(gen_step_one, tar[:,0:self.params.N_out_channels]) + + if self.params.orography: + gen_step_two = self.model(torch.cat( (gen_step_one, orog), axis = 1) ).to(self.device, dtype = torch.float) + else: + gen_step_two = self.model(gen_step_one).to(self.device, dtype = torch.float) + + loss_step_two = self.loss_obj(gen_step_two, tar[:,self.params.N_out_channels:2*self.params.N_out_channels]) + valid_loss[i] = loss_step_one + loss_step_two + valid_l1[i] = nn.functional.l1_loss(gen_step_one, tar[:,0:self.params.N_out_channels]) + else: + gen = self.model(inp) + valid_loss[i] += self.loss_obj(gen, tar) + valid_l1[i] += nn.functional.l1_loss(gen, tar) + + if self.params.two_step_training: + for c in range(self.params.N_out_channels): + if 'residual_field' in self.params.target: + valid_weighted_rmse[i, c] = weighted_rmse_torch((gen_step_one[0,c] + inp[0,c]), (tar[0,c]+inp[0,c]), self.device) + else: + valid_weighted_rmse[i, c] = weighted_rmse_torch(gen_step_one[0,c], tar[0,c], self.device) + else: + for c in range(self.params.N_out_channels): + if 'residual_field' in self.params.target: + valid_weighted_rmse[i, c] = weighted_rmse_torch((gen[0,c] + inp[0,c]), (tar[0,c]+inp[0,c]), self.device) + else: + valid_weighted_rmse[i, c] = weighted_rmse_torch(gen[0,c], tar[0,c], self.device) + + #un-normalize + valid_weighted_rmse = mult*torch.mean(valid_weighted_rmse[0:100], axis = 0).to(self.device) + + return valid_weighted_rmse + + + def load_model_wind(self, model_path): + if self.params.log_to_screen: + logging.info('Loading the wind model weights from {}'.format(model_path)) + checkpoint = torch.load(model_path, map_location='cuda:{}'.format(self.params.local_rank)) + if dist.is_initialized(): + self.model_wind.load_state_dict(checkpoint['model_state']) + else: + new_model_state = OrderedDict() + model_key = 'model_state' if 'model_state' in checkpoint else 'state_dict' + for key in checkpoint[model_key].keys(): + if 'module.' in key: # model was stored using ddp which prepends module + name = str(key[7:]) + new_model_state[name] = checkpoint[model_key][key] + else: + new_model_state[key] = checkpoint[model_key][key] + self.model_wind.load_state_dict(new_model_state) + self.model_wind.eval() + + def save_checkpoint(self, checkpoint_path, model=None): + """ We intentionally require a checkpoint_dir to be passed + in order to allow Ray Tune to use this function """ + + if not model: + model = self.model + + torch.save({'iters': self.iters, 'epoch': self.epoch, 'model_state': model.state_dict(), + 'optimizer_state_dict': self.optimizer.state_dict()}, checkpoint_path) + + def restore_checkpoint(self, checkpoint_path): + """ We intentionally require a checkpoint_dir to be passed + in order to allow Ray Tune to use this function """ + checkpoint = torch.load(checkpoint_path, map_location='cuda:{}'.format(self.params.local_rank)) + try: + self.model.load_state_dict(checkpoint['model_state']) + except: + new_state_dict = OrderedDict() + for key, val in checkpoint['model_state'].items(): + name = key[7:] + new_state_dict[name] = val + self.model.load_state_dict(new_state_dict) + self.iters = checkpoint['iters'] + self.startEpoch = checkpoint['epoch'] + if self.params.resuming: #restore checkpoint is used for finetuning as well as resuming. If finetuning (i.e., not resuming), restore checkpoint does not load optimizer state, instead uses config specified lr. + self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--run_num", default='00', type=str) + parser.add_argument("--yaml_config", default='./config/AFNO.bak.yaml', type=str) + parser.add_argument("--config", default='default', type=str) + parser.add_argument("--enable_amp", action='store_true') + parser.add_argument("--epsilon_factor", default = 0, type = float) + + args = parser.parse_args() + + params = YParams(os.path.abspath(args.yaml_config), args.config) + params['epsilon_factor'] = args.epsilon_factor + + params['world_size'] = 1 + if 'WORLD_SIZE' in os.environ: + params['world_size'] = int(os.environ['WORLD_SIZE']) + + world_rank = 0 + local_rank = 0 + if params['world_size'] > 1: + dist.init_process_group(backend='nccl', + init_method='env://') + local_rank = int(os.environ["LOCAL_RANK"]) + args.gpu = local_rank + world_rank = dist.get_rank() + params['global_batch_size'] = params.batch_size + params['batch_size'] = int(params.batch_size//params['world_size']) + + torch.cuda.set_device(local_rank) + torch.backends.cudnn.benchmark = True + + # Set up directory + expDir = os.path.join(params.exp_dir, args.config, str(args.run_num)) + if world_rank==0: + if not os.path.isdir(expDir): + os.makedirs(expDir) + os.makedirs(os.path.join(expDir, 'training_checkpoints/')) + + params['experiment_dir'] = os.path.abspath(expDir) + params['checkpoint_path'] = os.path.join(expDir, 'training_checkpoints/ckpt.tar') + params['best_checkpoint_path'] = os.path.join(expDir, 'training_checkpoints/best_ckpt.tar') + + # Do not comment this line out please: + args.resuming = True if os.path.isfile(params.checkpoint_path) else False + + params['resuming'] = args.resuming + params['local_rank'] = local_rank + params['enable_amp'] = args.enable_amp + + # this will be the wandb name +# params['name'] = args.config + '_' + str(args.run_num) +# params['group'] = "era5_wind" + args.config + params['name'] = args.config + '_' + str(args.run_num) + params['group'] = "era5_precip" + args.config + params['project'] = "ERA5_precip" + params['entity'] = "flowgan" + if world_rank==0: + logging_utils.log_to_file(logger_name=None, log_filename=os.path.join(expDir, 'out.log')) + logging_utils.log_versions() + params.log() + + params['log_to_wandb'] = (world_rank==0) and params['log_to_wandb'] + params['log_to_screen'] = (world_rank==0) and params['log_to_screen'] + + params['in_channels'] = np.array(params['in_channels']) + params['out_channels'] = np.array(params['out_channels']) + if params.orography: + params['N_in_channels'] = len(params['in_channels']) +1 + else: + params['N_in_channels'] = len(params['in_channels']) + + params['N_out_channels'] = len(params['out_channels']) + + if world_rank == 0: + hparams = ruamelDict() + yaml = YAML() + for key, value in params.params.items(): + hparams[str(key)] = str(value) + with open(os.path.join(expDir, 'hyperparams.yaml'), 'w') as hpfile: + yaml.dump(hparams, hpfile ) + + trainer = Trainer(params, world_rank) + trainer.train() + logging.info('DONE ---- rank %d'%world_rank) -- GitLab