From 0711ce89763124f953b655224c99f466a2f54894 Mon Sep 17 00:00:00 2001
From: xuetao chen <cxt@mail.ustc.edu.cn>
Date: Fri, 19 Jan 2024 11:24:13 +0800
Subject: [PATCH] initial

---
 .idea/.gitignore                              |   8 +
 .idea/FourCastNetTEC.iml                      |   8 +
 .../inspectionProfiles/profiles_settings.xml  |   6 +
 .idea/misc.xml                                |   4 +
 .idea/modules.xml                             |   8 +
 .idea/vcs.xml                                 |   6 +
 config/AFNO.bak.yaml                          | 131 ++++
 config/AFNO.yaml                              | 195 ++++++
 networks/__pycache__/afnonet.cpython-38.pyc   | Bin 0 -> 9351 bytes
 networks/afnonet.py                           | 283 ++++++++
 train.py                                      | 616 ++++++++++++++++++
 11 files changed, 1265 insertions(+)
 create mode 100644 .idea/.gitignore
 create mode 100644 .idea/FourCastNetTEC.iml
 create mode 100644 .idea/inspectionProfiles/profiles_settings.xml
 create mode 100644 .idea/misc.xml
 create mode 100644 .idea/modules.xml
 create mode 100644 .idea/vcs.xml
 create mode 100644 config/AFNO.bak.yaml
 create mode 100644 config/AFNO.yaml
 create mode 100644 networks/__pycache__/afnonet.cpython-38.pyc
 create mode 100644 networks/afnonet.py
 create mode 100644 train.py

diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/FourCastNetTEC.iml b/.idea/FourCastNetTEC.iml
new file mode 100644
index 0000000..bc6c1cf
--- /dev/null
+++ b/.idea/FourCastNetTEC.iml
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<module type="PYTHON_MODULE" version="4">
+  <component name="NewModuleRootManager">
+    <content url="file://$MODULE_DIR$" />
+    <orderEntry type="jdk" jdkName="castnet (3)" jdkType="Python SDK" />
+    <orderEntry type="sourceFolder" forTests="false" />
+  </component>
+</module>
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+<component name="InspectionProjectProfileManager">
+  <settings>
+    <option name="USE_PROJECT_PROFILE" value="false" />
+    <version value="1.0" />
+  </settings>
+</component>
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..466135d
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectRootManager" version="2" project-jdk-name="castnet (3)" project-jdk-type="Python SDK" />
+</project>
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..dffda97
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectModuleManager">
+    <modules>
+      <module fileurl="file://$PROJECT_DIR$/.idea/FourCastNetTEC.iml" filepath="$PROJECT_DIR$/.idea/FourCastNetTEC.iml" />
+    </modules>
+  </component>
+</project>
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..35eb1dd
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="VcsDirectoryMappings">
+    <mapping directory="" vcs="Git" />
+  </component>
+</project>
\ No newline at end of file
diff --git a/config/AFNO.bak.yaml b/config/AFNO.bak.yaml
new file mode 100644
index 0000000..72c8fb9
--- /dev/null
+++ b/config/AFNO.bak.yaml
@@ -0,0 +1,131 @@
+### base config ###
+full_field: &FULL_FIELD
+  loss: 'l2'
+  lr: 1E-3
+  scheduler: 'ReduceLROnPlateau'
+  num_data_workers: 4
+  dt: 1 # how many timesteps ahead the model will predict
+  n_history: 0 #how many previous timesteps to consider
+  prediction_type: 'iterative'
+  prediction_length: 41 #applicable only if prediction_type == 'iterative'
+  n_initial_conditions: 5 #applicable only if prediction_type == 'iterative'
+  ics_type: "default"
+  save_raw_forecasts: !!bool True
+  save_channel: !!bool False
+  masked_acc: !!bool False
+  maskpath: None
+  perturb: !!bool False
+  add_grid: !!bool False
+  N_grid_channels: 0
+  gridtype: 'sinusoidal' #options 'sinusoidal' or 'linear'
+  roll: !!bool False
+  max_epochs: 50
+  batch_size: 64
+
+  #afno hyperparams
+  num_blocks: 8
+  nettype: 'afno'
+  patch_size: 8
+  width: 56
+  modes: 32
+  #options default, residual
+  target: 'default' 
+  in_channels: [0,1]
+  out_channels: [0,1] #must be same as in_channels if prediction_type == 'iterative'
+  normalization: 'zscore' #options zscore (minmax not supported) 
+  train_data_path: '/pscratch/sd/j/jpathak/wind/train'
+  valid_data_path: '/pscratch/sd/j/jpathak/wind/test'
+  inf_data_path: '/pscratch/sd/j/jpathak/wind/out_of_sample' # test set path for inference
+  exp_dir: '/pscratch/sd/j/jpathak/ERA5_expts_gtc/wind'
+  time_means_path:   '/pscratch/sd/j/jpathak/wind/time_means.npy'
+  global_means_path: '/pscratch/sd/j/jpathak/wind/global_means.npy'
+  global_stds_path:  '/pscratch/sd/j/jpathak/wind/global_stds.npy'
+
+  orography: !!bool False
+  orography_path: None
+
+  log_to_screen: !!bool True
+  log_to_wandb: !!bool True
+  save_checkpoint: !!bool True
+
+  enable_nhwc: !!bool False
+  optimizer_type: 'FusedAdam'
+  crop_size_x: None
+  crop_size_y: None
+
+  two_step_training: !!bool False
+  plot_animations: !!bool False
+
+  add_noise: !!bool False
+  noise_std: 0
+
+afno_backbone: &backbone
+  <<: *FULL_FIELD
+  log_to_wandb: !!bool True
+  lr: 5E-4
+  batch_size: 2
+  max_epochs: 150
+  scheduler: 'CosineAnnealingLR'
+  in_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
+  out_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
+  orography: !!bool False
+  orography_path: None
+  exp_dir: '/mnt/data/work/FourCastNet/results/nowcasting'
+  train_data_path: '/mnt/data/work/FourCastNet/data/train'
+  valid_data_path: '/mnt/data/work/FourCastNet/data/test'
+  inf_data_path: '/mnt/data/work/FourCastNet/data/out_of_sample'
+  time_means_path: '/mnt/data/work/FourCastNet/data/time_means.npy'
+  global_means_path: '/mnt/data/work/FourCastNet/data/global_means.npy'
+  global_stds_path: '/mnt/data/work/FourCastNet/data/global_stds.npy'
+
+afno_backbone_orography: &backbone_orography 
+  <<: *backbone
+  orography: !!bool True
+  orography_path: '/pscratch/sd/s/shas1693/data/era5/static/orography.h5'
+
+afno_backbone_finetune: 
+  <<: *backbone
+  lr: 1E-4
+  batch_size: 64
+  log_to_wandb: !!bool True
+  max_epochs: 50
+  pretrained: !!bool True
+  two_step_training: !!bool True
+  pretrained_ckpt_path: '/pscratch/sd/s/shas1693/results/era5_wind/afno_backbone/0/training_checkpoints/best_ckpt.tar'
+
+perturbations:
+  <<: *backbone
+  lr: 1E-4
+  batch_size: 64
+  max_epochs: 50
+  pretrained: !!bool True
+  two_step_training: !!bool True
+  pretrained_ckpt_path: '/pscratch/sd/j/jpathak/ERA5_expts_gtc/wind/afno_20ch_bs_64_lr5em4_blk_8_patch_8_cosine_sched/1/training_checkpoints/best_ckpt.tar'
+  prediction_length: 24
+  ics_type: "datetime"
+  n_perturbations: 100 
+  save_channel: !!bool True
+  save_idx: 4
+  save_raw_forecasts: !!bool False
+  date_strings: ["2018-01-01 00:00:00"] 
+  inference_file_tag: " "
+  valid_data_path: "/pscratch/sd/j/jpathak/ "
+  perturb: !!bool True
+  n_level: 0.3
+
+### PRECIP ###
+precip: &precip
+  <<: *backbone
+  in_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
+  out_channels: [0]
+  nettype: 'afno'
+  nettype_wind: 'afno'
+  log_to_wandb: !!bool True
+  lr: 2.5E-4
+  batch_size: 64
+  max_epochs: 25
+  precip: '/pscratch/sd/p/pharring/ERA5/precip/total_precipitation'
+  time_means_path_tp: '/pscratch/sd/p/pharring/ERA5/precip/total_precipitation/time_means.npy'
+  model_wind_path: '/pscratch/sd/s/shas1693/results/era5_wind/afno_backbone_finetune/0/training_checkpoints/best_ckpt.tar'
+  precip_eps: !!float 1e-5
+
diff --git a/config/AFNO.yaml b/config/AFNO.yaml
new file mode 100644
index 0000000..765146c
--- /dev/null
+++ b/config/AFNO.yaml
@@ -0,0 +1,195 @@
+### base config ###
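+# Every experiment block below merges this base block through YAML anchors (<<: *FULL_FIELD
+# or <<: *backbone) and overrides individual keys; train.py selects one by name via --config.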
+full_field: &FULL_FIELD
+  loss: 'l2'
+  lr: 1E-3
+  scheduler: 'ReduceLROnPlateau'
+  num_data_workers: 4
+  dt: 1 # how many timesteps ahead the model will predict
+  n_history: 0 #how many previous timesteps to consider
+  prediction_type: 'iterative'
+  prediction_length: 41 #applicable only if prediction_type == 'iterative'
+  n_initial_conditions: 5 #applicable only if prediction_type == 'iterative'
+  ics_type: "default"
+  save_raw_forecasts: !!bool True
+  save_channel: !!bool False
+  masked_acc: !!bool False
+  maskpath: None
+  perturb: !!bool False
+  add_grid: !!bool False
+  N_grid_channels: 0
+  gridtype: 'sinusoidal' #options 'sinusoidal' or 'linear'
+  roll: !!bool False
+  max_epochs: 50
+  batch_size: 64
+
+  #afno hyperparams
+  num_blocks: 8
+  nettype: 'afno'
+  patch_size: 8
+  width: 56
+  modes: 32
+  #options default, residual
+  target: 'default' 
+  in_channels: [0,1]
+  out_channels: [0,1] #must be same as in_channels if prediction_type == 'iterative'
+  normalization: 'zscore' #options zscore (minmax not supported) 
+  train_data_path: '/pscratch/sd/j/jpathak/wind/train'
+  valid_data_path: '/pscratch/sd/j/jpathak/wind/test'
+  inf_data_path: '/pscratch/sd/j/jpathak/wind/out_of_sample' # test set path for inference
+  exp_dir: '/pscratch/sd/j/jpathak/ERA5_expts_gtc/wind'
+  time_means_path:   '/pscratch/sd/j/jpathak/wind/time_means.npy'
+  global_means_path: '/pscratch/sd/j/jpathak/wind/global_means.npy'
+  global_stds_path:  '/pscratch/sd/j/jpathak/wind/global_stds.npy'
+
+  orography: !!bool False
+  orography_path: None
+
+  log_to_screen: !!bool True
+  log_to_wandb: !!bool True
+  save_checkpoint: !!bool True
+
+  enable_nhwc: !!bool False
+  optimizer_type: 'FusedAdam'
+  crop_size_x: None
+  crop_size_y: None
+
+  two_step_training: !!bool False
+  plot_animations: !!bool False
+
+  add_noise: !!bool False
+  noise_std: 0
+
+
+afno_backbone_cxtwin:
+  <<: *FULL_FIELD
+  log_to_wandb: !!bool True
+  lr: 5E-4
+  batch_size: 16
+  max_epochs: 1510
+  scheduler: 'CosineAnnealingLR'
+  in_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
+  out_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
+#  in_channels: [0, 1, 2, 3, 4, 5]
+#  out_channels: [0, 1, 2, 3, 4, 5]
+  orography: !!bool False
+  orography_path: None
+  exp_dir: '/home/cxt/work/fourcastnet/FourCastNetwithobswin/results/ljkj'
+  train_data_path: '/home/cxt/work/fourcastnet/FourCastNetwithobswin/data_ljkj/train'
+  valid_data_path: '/home/cxt/work/fourcastnet/FourCastNetwithobswin/data_ljkj/test'
+  inf_data_path:   '/home/cxt/work/fourcastnet/FourCastNetwithobswin/data_ljkj/out_of_sample'
+  time_means_path:   '/home/cxt/work/fourcastnet/FourCastNetwithobswin/data_ljkj/time_means.npy'
+  global_means_path: '/home/cxt/work/fourcastnet/FourCastNetwithobswin/data_ljkj/global_means.npy'
+  global_stds_path:  '/home/cxt/work/fourcastnet/FourCastNetwithobswin/data_ljkj/global_stds.npy'
+
+afno_backbone_cxt: &backbone_cxt
+  <<: *FULL_FIELD
+  log_to_wandb: !!bool True
+  lr: 5E-4
+  batch_size: 2
+  max_epochs: 150
+  scheduler: 'CosineAnnealingLR'
+#  in_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
+#  out_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
+  in_channels: [0, 1, 2, 3, 4, 5]
+  out_channels: [0, 1, 2, 3, 4, 5]
+  orography: !!bool False
+  orography_path: None
+  exp_dir: '/mnt/data/work/FourCastNet_linux/results/era5_wind'
+  train_data_path: '/mnt/data/work/FourCastNet_linux/data_nowcasting/train'
+  valid_data_path: '/mnt/data/work/FourCastNet_linux/data_nowcasting/test'
+  inf_data_path:   '/mnt/data/work/FourCastNet_linux/data_nowcasting/out_of_sample'
+  time_means_path:   '/mnt/data/work/FourCastNet_linux/data_nowcasting/time_means.npy'
+  global_means_path: '/mnt/data/work/FourCastNet_linux/data_nowcasting/global_means.npy'
+  global_stds_path:  '/mnt/data/work/FourCastNet_linux/data_nowcasting/global_stds.npy'
+
+afno_backbone_nowcasting:
+  <<: *FULL_FIELD
+  log_to_wandb: !!bool True
+  lr: 5E-4
+  batch_size: 2
+  max_epochs: 150
+  scheduler: 'CosineAnnealingLR'
+#  in_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
+#  out_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
+  in_channels: [0]
+  out_channels: [0]
+  orography: !!bool False
+  orography_path: None
+  exp_dir: '/mnt/data/work/FourCastNet/results/nowcasting'
+  train_data_path: '/mnt/data/work/FourCastNet/data_nowcasting/train'
+  valid_data_path: '/mnt/data/work/FourCastNet/data_nowcasting/test'
+  inf_data_path:   '/mnt/data/work/FourCastNet/data_nowcasting/out_of_sample'
+  time_means_path:   '/mnt/data/work/FourCastNet/data_nowcasting/time_means.npy'
+  global_means_path: '/mnt/data/work/FourCastNet/data_nowcasting/global_means.npy'
+  global_stds_path:  '/mnt/data/work/FourCastNet/data_nowcasting/global_stds.npy'
+
+afno_backbone: &backbone
+  <<: *FULL_FIELD
+  log_to_wandb: !!bool True
+  lr: 5E-4
+  batch_size: 64
+  max_epochs: 150
+  scheduler: 'CosineAnnealingLR'
+  in_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
+  out_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
+  orography: !!bool False
+  orography_path: None 
+  exp_dir: '/pscratch/sd/s/shas1693/results/era5_wind'
+  train_data_path: '/pscratch/sd/s/shas1693/data/era5/train'
+  valid_data_path: '/pscratch/sd/s/shas1693/data/era5/test'
+  inf_data_path:   '/pscratch/sd/s/shas1693/data/era5/out_of_sample'
+  time_means_path:   '/pscratch/sd/s/shas1693/data/era5/time_means.npy'
+  global_means_path: '/pscratch/sd/s/shas1693/data/era5/global_means.npy'
+  global_stds_path:  '/pscratch/sd/s/shas1693/data/era5/global_stds.npy'
+
+afno_backbone_orography: &backbone_orography 
+  <<: *backbone
+  orography: !!bool True
+  orography_path: '/pscratch/sd/s/shas1693/data/era5/static/orography.h5'
+
+afno_backbone_finetune: 
+  <<: *backbone
+  lr: 1E-4
+  batch_size: 64
+  log_to_wandb: !!bool True
+  max_epochs: 50
+  pretrained: !!bool True
+  two_step_training: !!bool True
+  pretrained_ckpt_path: '/pscratch/sd/s/shas1693/results/era5_wind/afno_backbone/0/training_checkpoints/best_ckpt.tar'
+
+perturbations:
+  <<: *backbone
+  lr: 1E-4
+  batch_size: 64
+  max_epochs: 50
+  pretrained: !!bool True
+  two_step_training: !!bool True
+  pretrained_ckpt_path: '/pscratch/sd/j/jpathak/ERA5_expts_gtc/wind/afno_20ch_bs_64_lr5em4_blk_8_patch_8_cosine_sched/1/training_checkpoints/best_ckpt.tar'
+  prediction_length: 24
+  ics_type: "datetime"
+  n_perturbations: 100 
+  save_channel: !!bool True
+  save_idx: 4
+  save_raw_forecasts: !!bool False
+  date_strings: ["2018-01-01 00:00:00"] 
+  inference_file_tag: " "
+  valid_data_path: "/pscratch/sd/j/jpathak/ "
+  perturb: !!bool True
+  n_level: 0.3
+
+### PRECIP ###
+precip: &precip
+  <<: *backbone
+  in_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
+  out_channels: [0]
+  nettype: 'afno'
+  nettype_wind: 'afno'
+  log_to_wandb: !!bool True
+  lr: 2.5E-4
+  batch_size: 64
+  max_epochs: 25
+  precip: '/pscratch/sd/p/pharring/ERA5/precip/total_precipitation'
+  time_means_path_tp: '/pscratch/sd/p/pharring/ERA5/precip/total_precipitation/time_means.npy'
+  model_wind_path: '/pscratch/sd/s/shas1693/results/era5_wind/afno_backbone_finetune/0/training_checkpoints/best_ckpt.tar'
+  precip_eps: !!float 1e-5
+
diff --git a/networks/__pycache__/afnonet.cpython-38.pyc b/networks/__pycache__/afnonet.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..52e1f283c268f643e47d3571e726bd962162820d
GIT binary patch
literal 9351
zcmb7KYiuM}R<2vGu70@P<MDeQ=}fYj>B+>NG0RJY$zvu7fvhtmSy(l@)VBLpySrWe
zm|NvI?rjUSp2&_C!=l}Vg;h{nLcj=E_`&j<U;INt{Hz}c#E-~ANbm!#6p0PrIaU4e
z%;Z(Ad+NUH*16}L`<+w$<*BKRhU<;hKl8d*HSIsBFnQT1JdYF{3Qd!m6uwrbyY7pY
zUe*QW4c};)Ws~!!Z?){Q-Aa{HT(*3tl`f|_Z~K{6ww&d>@N?xUo-bd{^L@HiC>N5p
zGp&NI39U6#KFc*Jf3`JOo@>pQ=aU`_Jj(GGTTA67u1ovNt#jpboX_~@TNla~S{KU~
zxt#SMYZc2y&gcBcTbIh0g!YOir)2(~Ci7mQ`NY0aeiAj)vVfXGQu77WoRKrAnMrD%
zLd{t@i<;S_<}zyL<UDHTlbS22S&)mUSxjn5-c`A@rGHzfd)^hfEYHdF@`Aj0WR{=y
z@`G#gF_eq)@nf_6jQ6zn%)ThEc&pw@^ZLFJ8glu;4R7UGP)Uk6wbG>%+N1T7h*I5(
z3L6z4+3f2|ddicx8`ZF+N7hS9byVa?&+Aq@-Cgt$Z-N-zibS(>TXnj3Dq%gE3e|SI
z>b5(oRq<W)$h_&@-S*nav~$(ESKaD%8tu>x#tnLuR$fJ^N_*3bY~^*mN{F%PH-?oc
zOx^KRqaz#DI~BPmtF-3H%S;Nvb!0rPETmS}rI5linx@oi=7Cld(var8blH-YwC`zU
z`#>|batb_1MaGx>?)nGA8LK3!!;A7Xy@3?0gM`|?HW0@Us=m=J^u?ah*Sm|j>!Cg{
z`l7EN>wV2N`{ureyT!G7-{4vsrBvTK7HDy#_NI1Ute4V}6>N7saM*Pl?MCRjk=|}c
z_G^tcSP&Vt>W#>#R71+GMJZx;XFH5ciA>3k%)s+&QMS=`YvA0r@`7l(-jLFpsK{Z!
zxH*l<T)(pGDRmBSbRW!i-Kt*+0@uB-eRzAd-f4NO)px?xosQaCt#!6lwGxDFFI;^Y
z`CG`Zd*Mzatamnookn{VEz~nut<>5b6jr*sY8Epu64UY^n#c$TzdWa24*XsjzHC*e
z4#`D%V)ZSg;AcSk+Cb}zW3dO84a6v?k8)sn(mJezT&o!jEtzToH$|)97gbBh-PhH5
z+#>OgdV(65JV`aFT1V|vlzbe8AO}LL9{ze+T65B)<j`KVD=iOum2urxM{fI+&$;g1
z?TR0_s5zR|eC4IrzRcWEm#I9Ju)6B`zUxk`fgY=9A*ujED-THU9I=?@nkovXE^4JK
zcnV}B^$yFrDJ{slA#It0tV8}G>lS1^9oa9uy#D&y?Gp!sKKRD$4`2EB-o58eV3W0n
z<em(=ffRfQG9mYn@p+h!Ej_pz>Jn`K0Zeya|AF?txi7wnO@eeAeS@`tbiwLtO_Z9}
zGHHa}H$kBVppK+`I?N2RXitZ^!IY%ll<Xuqv)ZwVM6>2UwkP#6l3vrF)Qe^v%~yDI
z&(Tb0NJmTdx+4Dca*0|6jh<KRT`C6k&bBX$8(vX1-fjdNzE|AXEw;B??uOr~ZUx1J
zzu$fCgygz3qt4>SvJmB82m%iVx6_WTYj$GpXeNL$4jSRE8`h!E^^Py2%y`}?!}Ghe
zj%c|KksbGIv^U+Ff`Vh_Xl7zkw}wiXaVw}+d@q*G^qq==5Qnfm7TT)y$buP`?MUCb
z5$PL9J8QVFm2B1-^+j4fX{GvUB0opu=ZUcFjx>qi@M+Ws%OIphMco$k%QLn0tKunf
zRrID~S$wo9AiI%8x`h;cA96@GwlA6*EUq9643i+5`d7_8bI;nd_fjE@JtPNaU+U60
zGN^q((q{E-X)bHhT7djDQ%KISwI|R&{peQo$joY+n5&tU_L08le2%_p?u(U;MrR(X
z$@Z<`iuQ70uAjpbr)IT%u{Raw2h*W3D99Av1fb|$9j#M~8qo!mq}HE$SFagMTJsD>
z%}8enlK&Y;oaGS~jlhgrbCyOgkZ$2L`uJx(t9_1D#n1RWtBT))SH(LtE0O0_r9Wj=
zF`qth?Mw8YbEmlUiL)mC$2^*kdDPE*rnP*QxhHw^S=K)0%zVt5e)iMW(p>0g`q@og
zrey~5E5dve7zLbIl36YrgXQGld{{?P$0los**u3a(4HEBDP87{P!F9t4>UsuGR}bo
zvX4{iC0!BpQfnZQ)lonOK!%2Fv?7z>AUALqfH_-wZ#Sx5uduSRveCG<(dayLld>JQ
zvL_$mDhWkukIt4~X>G`o5s84S)NQK25Lq(Z?Rt?_^E;IgMy*}}f_j!7Z`5k2P@n)y
zWId^cjn>cd>Uy5)1hLtAiTeYOvMGzi%aN%(e><|hMmyNX+5t>`b(4CTjaFq-J%_A%
zp2!PSX$4`Wx)o)DPAv@Ts?pwJ!x&A!-SBqYO5j#It*-CA6WNV;Jti%>l#i_`9ovz)
z(WnIKRmzGNBk@a-`1MHKDyKR(TzWWX=+Vfm4Qtk1uE~a-5I7s)!wbrpEw3BK`3X$H
zaDre5aV98(Xu^ao1`^X}#6|s@xF%-AjFCgx6l-GbBVBiXqJMPBK2=Ix7uS=P;$+^Q
z5d~QHMf@^)FP~WRAxJXS3ZAD{iLi7Kxo@j0R8b-VKQTTI85*f;ROC~H2_Gu_2G5{@
z`X+!Ve@0PF73tKKPD%m2?FbN&<m%)IAX(%DkONRlGfYZJ2QbMdEdfkQF-&qG(b>p)
zk&f`hCn!9x&S1*7#xImK{P%NJI--Ubfktu|45Xk7G69<GKD;hK7(FB>!Re>_nSBC3
za1;QgO!dXkg3BP<c4!Y$#{`s|{%{Nc5-A%W5ZN$`JPwe)A@w7gA=k%|=E4WY!i{xa
zDWLBKJYHQSQY1n^AA&aqAQtK64Q6Cy*BU-RfFev7BLnAE_zxivNR%cw!|jrP$3xcC
z9qLs8iHw%tjk2<{4d~+rTa9ih7mM~SYE7fnRTUtS7=3TjxFWS==-G(}(xlqcn4+Ag
zkqLHz4S5v2P#6}_BD5)j=moK?XG9Ld+{^RMJsMNc#!#a51#f{crg)@_(bUz|*us&e
z1M<+wyGvXRbCd%{7~4YXIdnd3B{9G9i10hsDc2xEj3^lsM!<oYYDhQ^6Vg9Y-S^Nt
zm<NIC;!QKGz9!z5evar(#qT@Ry7YtkMP7&^G!~^_BcE@b>@j&~pQ3aXbh(Sf&vr~*
z(9gs#&guV~Ur^MiigfBC2GbSLD-6|;)eWk6&w*D$p7^9!VnyjY%BwcIa4W0B;WP#1
zNhs+a5CbO^ij@f%3%k-ZYUF27jue7$8@vbj7*y+oX&I|tCd@VoX%PZU!76Y_Kw3xC
zcPh*d@~n2#L(G32&pV+f^SSg^$&9jFo@#r347}|iR1N7-;6zHmZuJ}JS}IJam<<$P
zX$4UR+EcB^qo&s3<Ga;*rHzS#$Vm!8lwWtrNguVOhb45{z$vo<$68I}5s}%&6Js^C
zZ*|&lugS=);vSiAcyKkz-5p9-NUQG=VPpCP<yf$2|NjUnAV--HI$?A2&~4<;i{8v=
zKOU_}r~GuX+n?D_A1R0#FYFr_toDeopZ*2P@jAzgy^Oj)MG7(?W2d}#X1vHm`qkIb
zVPg3B2<f2V`$)+VlFKTpsUn@a*zHG6@G{6ycI6yPZ&v2mr{-8++4QWY4Kd<4AL49I
z9N#Oa<EUOa!-q*0HhDfu(a8wezOSGB1LkcYeuy~437IdoPlRyO2}$`0EkJD$p%W)=
zU8Ni?waT#+WayekB8FfTY#C;p)Qv;i)%!YVtP{|S8kJzTnilRjSehx^okQK!a8D!Q
zh-qeV&xW~hs+lLmcci~(h0`=o(?qfl$*)4rEud6rp4rQXQ-hiC>|hpteEv{B1};SF
z9LCQd3TPIU7UCcTdM!!^s5EtK+=i#~p4Qh1S<~EfU#4{|q0UU|aOb*ZJmcJfb|4P)
z1LMFvunz2lRL#Meox10Q=liLBLG$sGFC1Ej`aV7NBG&NOA+6jQK4CU~7UrGRWCl3c
zK)+)6INIlAPEH*m2z_V|E}{IyF$GBJK+Ugj9{tyk20#ACzr6`GYIxlMcED-0Hd&q3
zHkhw=K%%tQ+VG@H#}ye4bQawzc4c&t%9u!RA_b&>n^2ms<rq<G$J+1fcTJofPoM!~
z%|__80HO7j>xyvVgiMqEMww^*Mi3Iby!lTkvsBxfuJxWwL~W$DE6)PRRQ%v3GQ%Nk
z&>mVGi5}sD;3^0&?h@Pp4NNo-#UYMLYz%~_DGpPIhSZLnc*!MEa$;z{g>e%{y-&r{
zCp-Lv%CHvsZy_Iv#(g>yx+;MhKn$Kbq+T3f;7NHZ3fskj2Vzoe$9@O7e?j7Soo(np
zDF_p6&_8+^E(955bpzx+b}k5I>B+HiR%{v%s0Z)i+mxe?RPPYk1&MMLav26y<0bLy
z&`z@bRc-2)LGZg;e~Gp~%EqRiWg*Jac&_jQ^$RqJfT?<&2Bo{5z~yd=f+UecWrA9f
z<HPyrR;pek>XF!qGGF58;%nGE4$n}OoZ~+M`$CGKL1flFoCySgBdgNw`n!lTax8hr
zYi!oTpfnxBjR)SEUzpyN-wsX{oQS$&mj51V@o8iT{s_n=ZqzvT^F12*IFa8aGCEoQ
z3Z;O+zjlUBmf_iwW4~iwIQl&NjYWM{%<AV2`~NalM!%xxMX&In@IDk9CIz}hvlIRc
zA(J5jE&IB9kx(@Z2^%dj+0rJHC$+7#YNIc(HJ1^N7h3yTGJ}H|9%>!+Hl<0u*&JBw
zEKus(csy$n+hv4RGmQY-7FOETIJU;y5^snd?*WS=6dMu#D2?-)I)^Tc5LAJLI})u2
z^jr;4{jW$#6dit=m<Pzs0r2+bM%<d*=~HS=_J^dESw=vL7*y9|N8t}pE~l%0;D((o
zubr5lSWZ~OMrsyKw>xh97OwQFmE8|8G@yw!y|*x2#shs$?H%nKc}BdDUqc4FfW<6n
z46Fg06A-O65mJCj9D{8!FbT$uLc$!JK-sxX#*@O#0C#)h?rnnJB|Q$+meT51@DTM~
zBJ`f>4~eW3A=sn-m<Su0_mP9U#Kwg!?N?DzHa7gN2lgT=Bp_pC68;lCV;_XEIAQa7
z%IUquAx|D0a*8R$KvG#|%2$wi6o?8DN0bz~fxE}CDZppG<HJ@`#+@ST5dXM^dm!Fv
z9@()+QnJ>2*EWjv;!g2O_eQb1R;&`(EES)*S=@;8#Vhq<_eQCBWe3frShV>ZGb6o=
zKyY_$%(#-J{sOC2e?lZjgiR<L(~rY%KSsC2iy^-y_8|yxBK+{7?ZlyAK%~(pmYn)+
zn)#bV2p+}c9_8L8vO$C~(j$Np{Lm}ZULr!rS?ur(C`UFpGMdmryRivJ)Kh;>)%-@6
zDR+&?SOM6+v1$7)H2ftJJ2sAmPk@j5{ZD?)Ndrb3PEOH?D$=Qobb+qZ-)-W61bv&N
z?=}>enDme`>N^+#%=mA(sUH$y&U7h9OsQHR!;3VHl7Yl-?j$rMtpm&gh0wqpaEV1|
zR@;N?1ZN1)5!B#MG-e2#3oRopWcY`c6x%{`Jr3)Tk%0xfi<=^|3^+hAB4#-Y1&aqC
z7IW0l2Z0x(I?dsSg>f%FkytUgKQo{Af_H#mla1$%KF<~Tc(FzA?D!BKL7!8~M*JYS
zd<!X{V~{KmJ|5_M27N%Fk5Wx2+LE?Kjz4{XV&32**H`e73-6DQT=<j#8n>Y**ChRr
zx8G0m=dZ+e4NRFj76S{8X6hIrShi12Y@e8My_r|r-R-bQ!91_XAC;~|;%aZ9RFoYr
zXkQMCEkYJWitzcxD?Pom!dd{Xv2|mv{S;nB{S64dU;33W^x9E6#8;dEfWjN{cVyU<
zOWnUl<Tr^hYsdU0`Og!79b_~dHouI}9-P4uz!|Bhss5VApB6&(b*dZB$-BoJN78H2
z#3cNPJ~}MV^cA&LfeY-G^pk(q;r<LD#>i7I_4h<bz)xu3V51+962>iLqtxFK`4M$m
zi%jsW9;IveCqmfi_<_Q-T9mDJeBa}*Vu3QK1}-zcQuBjIbs`;xfMldR1Ju^tUGW<m
zEA+2~U2?Wl)WhNie4FpaHpJmiW-IMBW5M#-VR40@jqh(2Kgx{r>Q{O8u+eI*@Qwy6
z{0%;c!)|mCl+*DXh+UVLN7JD*zAgtVzziXng*YAYBbtvpY=;dWZKJ<ds4O*71iRA>
zV(UCbQ=g4%l4q{)V3nu7)A1DK({L`1GT1Tf<Uk)s31j5m!3Z{Myn$?GBa2^0{USd&
zs8Gn8{>Tw*`#`K+gvG;HUT@Rbv2={YE^?lY-|;5xOhB;T0mLDYLzGSlME3Lo-;-Kn
U=IqQXGnYZjGX;wOQENv0KV@JWN&o-=

literal 0
HcmV?d00001

diff --git a/networks/afnonet.py b/networks/afnonet.py
new file mode 100644
index 0000000..8d05ec0
--- /dev/null
+++ b/networks/afnonet.py
@@ -0,0 +1,283 @@
+#reference: https://github.com/NVlabs/AFNO-transformer
+
+import math
+from functools import partial
+from collections import OrderedDict
+from copy import Error, deepcopy
+from re import S
+from numpy.lib.arraypad import pad
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+#from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.models.layers import DropPath, trunc_normal_
+import torch.fft
+from torch.nn.modules.container import Sequential
+from torch.utils.checkpoint import checkpoint_sequential
+from einops import rearrange, repeat
+from einops.layers.torch import Rearrange
+from utils.img_utils import PeriodicPad2d
+
+
+class Mlp(nn.Module):
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class AFNO2D(nn.Module):
+    def __init__(self, hidden_size, num_blocks=8, sparsity_threshold=0.01, hard_thresholding_fraction=1, hidden_size_factor=1):
+        super().__init__()
+        assert hidden_size % num_blocks == 0, f"hidden_size {hidden_size} should be divisible by num_blocks {num_blocks}"
+
+        self.hidden_size = hidden_size
+        self.sparsity_threshold = sparsity_threshold
+        self.num_blocks = num_blocks
+        self.block_size = self.hidden_size // self.num_blocks
+        self.hard_thresholding_fraction = hard_thresholding_fraction
+        self.hidden_size_factor = hidden_size_factor
+        self.scale = 0.02
+
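+        # w1/b1 and w2/b2 store the real and imaginary parts (leading dimension of size 2)
+        # of a two-layer MLP applied independently to each of the num_blocks channel groups,
+        # so the channel mixing in Fourier space is block-diagonal rather than dense.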
+        self.w1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size, self.block_size * self.hidden_size_factor))
+        self.b1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor))
+        self.w2 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor, self.block_size))
+        self.b2 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size))
+
+    def forward(self, x):
+        bias = x
+
+        dtype = x.dtype
+        x = x.float()
+        B, H, W, C = x.shape
+
+        x = torch.fft.rfft2(x, dim=(1, 2), norm="ortho")
+        x = x.reshape(B, H, W // 2 + 1, self.num_blocks, self.block_size)
+
+        o1_real = torch.zeros([B, H, W // 2 + 1, self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device)
+        o1_imag = torch.zeros([B, H, W // 2 + 1, self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device)
+        o2_real = torch.zeros(x.shape, device=x.device)
+        o2_imag = torch.zeros(x.shape, device=x.device)
+
+
+        total_modes = H // 2 + 1
+        kept_modes = int(total_modes * self.hard_thresholding_fraction)
+
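+        # Only a subset of the Fourier modes (controlled by hard_thresholding_fraction,
+        # 1.0 by default) is passed through the block MLP; all other modes stay zero in
+        # o1/o2 and are recovered only through the residual "bias" added at the end. The
+        # four einsum blocks below implement complex multiplication,
+        # (a+ib)(c+id) = (ac-bd) + i(ad+bc), with a ReLU between the two layers.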
+        o1_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = F.relu(
+            torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].real, self.w1[0]) - \
+            torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].imag, self.w1[1]) + \
+            self.b1[0]
+        )
+
+        o1_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = F.relu(
+            torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].imag, self.w1[0]) + \
+            torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].real, self.w1[1]) + \
+            self.b1[1]
+        )
+
+        o2_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes]  = (
+            torch.einsum('...bi,bio->...bo', o1_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[0]) - \
+            torch.einsum('...bi,bio->...bo', o1_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[1]) + \
+            self.b2[0]
+        )
+
+        o2_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes]  = (
+            torch.einsum('...bi,bio->...bo', o1_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[0]) + \
+            torch.einsum('...bi,bio->...bo', o1_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[1]) + \
+            self.b2[1]
+        )
+
+        x = torch.stack([o2_real, o2_imag], dim=-1)
+        x = F.softshrink(x, lambd=self.sparsity_threshold)
+        x = torch.view_as_complex(x)
+        x = x.reshape(B, H, W // 2 + 1, C)
+        x = torch.fft.irfft2(x, s=(H, W), dim=(1,2), norm="ortho")
+        x = x.type(dtype)
+
+        return x + bias
+
+
+class Block(nn.Module):
+    def __init__(
+            self,
+            dim,
+            mlp_ratio=4.,
+            drop=0.,
+            drop_path=0.,
+            act_layer=nn.GELU,
+            norm_layer=nn.LayerNorm,
+            double_skip=True,
+            num_blocks=8,
+            sparsity_threshold=0.01,
+            hard_thresholding_fraction=1.0
+        ):
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.filter = AFNO2D(dim, num_blocks, sparsity_threshold, hard_thresholding_fraction) 
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        #self.drop_path = nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+        self.double_skip = double_skip
+
+    def forward(self, x):
+        residual = x
+        x = self.norm1(x)
+        x = self.filter(x)
+
+        if self.double_skip:
+            x = x + residual
+            residual = x
+
+        x = self.norm2(x)
+        x = self.mlp(x)
+        x = self.drop_path(x)
+        x = x + residual
+        return x
+
+class PrecipNet(nn.Module):
+    def __init__(self, params, backbone):
+        super().__init__()
+        self.params = params
+        self.patch_size = (params.patch_size, params.patch_size)
+        self.in_chans = params.N_in_channels
+        self.out_chans = params.N_out_channels
+        self.backbone = backbone
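+        # output head applied after the backbone: periodic padding, a 3x3 convolution and
+        # a ReLU, which keeps the predicted precipitation non-negative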
+        self.ppad = PeriodicPad2d(1)
+        self.conv = nn.Conv2d(self.out_chans, self.out_chans, kernel_size=3, stride=1, padding=0, bias=True)
+        self.act = nn.ReLU()
+
+    def forward(self, x):
+        x = self.backbone(x)
+        x = self.ppad(x)
+        x = self.conv(x)
+        x = self.act(x)
+        return x
+
+class AFNONet(nn.Module):
+    def __init__(
+            self,
+            params,
+            # img_size=(720, 1440), modified by cxt 2023.12.31
+            img_size=(192, 288),
+            patch_size=(16, 16),
+            in_chans=2,
+            out_chans=2,
+            embed_dim=768,
+            depth=12,
+            mlp_ratio=4.,
+            drop_rate=0.,
+            drop_path_rate=0.,
+            num_blocks=16,
+            sparsity_threshold=0.01,
+            hard_thresholding_fraction=1.0,
+        ):
+        super().__init__()
+        self.params = params
+        self.img_size = img_size
+        self.patch_size = (params.patch_size, params.patch_size)
+        self.in_chans = params.N_in_channels
+        self.out_chans = params.N_out_channels
+        self.num_features = self.embed_dim = embed_dim
+        self.num_blocks = params.num_blocks 
+        norm_layer = partial(nn.LayerNorm, eps=1e-6)
+
+        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=self.patch_size, in_chans=self.in_chans, embed_dim=embed_dim)
+        num_patches = self.patch_embed.num_patches
+
+        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
+        self.pos_drop = nn.Dropout(p=drop_rate)
+
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
+
+        self.h = img_size[0] // self.patch_size[0]
+        self.w = img_size[1] // self.patch_size[1]
+
+        self.blocks = nn.ModuleList([
+            Block(dim=embed_dim, mlp_ratio=mlp_ratio, drop=drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
+            num_blocks=self.num_blocks, sparsity_threshold=sparsity_threshold, hard_thresholding_fraction=hard_thresholding_fraction) 
+        for i in range(depth)])
+
+        self.norm = norm_layer(embed_dim)
+
+        self.head = nn.Linear(embed_dim, self.out_chans*self.patch_size[0]*self.patch_size[1], bias=False)
+
+        trunc_normal_(self.pos_embed, std=.02)
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {'pos_embed', 'cls_token'}
+
+    def forward_features(self, x):
+        B = x.shape[0]
+        x = self.patch_embed(x)
+        x = x + self.pos_embed
+        x = self.pos_drop(x)
+        
+        x = x.reshape(B, self.h, self.w, self.embed_dim)
+        for blk in self.blocks:
+            x = blk(x)
+
+        return x
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.head(x)
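+        # fold the per-patch predictions (patch_size[0]*patch_size[1]*out_chans values per
+        # token) back into a (B, out_chans, H, W) image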
+        x = rearrange(
+            x,
+            "b h w (p1 p2 c_out) -> b c_out (h p1) (w p2)",
+            p1=self.patch_size[0],
+            p2=self.patch_size[1],
+            h=self.img_size[0] // self.patch_size[0],
+            w=self.img_size[1] // self.patch_size[1],
+        )
+        return x
+
+
+class PatchEmbed(nn.Module):
+    def __init__(self, img_size=(224, 224), patch_size=(16, 16), in_chans=3, embed_dim=768):
+        super().__init__()
+        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.num_patches = num_patches
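+        # non-overlapping convolution (stride == kernel_size): every patch of the input
+        # becomes a single embed_dim-dimensional token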
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+    def forward(self, x):
+        B, C, H, W = x.shape
+        assert H == self.img_size[0] and W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+        x = self.proj(x).flatten(2).transpose(1, 2)
+        return x
+
+
+if __name__ == "__main__":
+    model = AFNONet(img_size=(720, 1440), patch_size=(4,4), in_chans=3, out_chans=10)
+    sample = torch.randn(1, 3, 720, 1440)
+    result = model(sample)
+    print(result.shape)
+    print(torch.norm(result))
+
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..89e33e7
--- /dev/null
+++ b/train.py
@@ -0,0 +1,616 @@
+#BSD 3-Clause License
+#
+#Copyright (c) 2022, FourCastNet authors
+#All rights reserved.
+#
+#Redistribution and use in source and binary forms, with or without
+#modification, are permitted provided that the following conditions are met:
+#
+#1. Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+#2. Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+#3. Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+#AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+#IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+#DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+#FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+#DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+#SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+#CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+#OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+#OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+#The code was authored by the following people:
+#
+#Jaideep Pathak - NVIDIA Corporation
+#Shashank Subramanian - NERSC, Lawrence Berkeley National Laboratory
+#Peter Harrington - NERSC, Lawrence Berkeley National Laboratory
+#Sanjeev Raja - NERSC, Lawrence Berkeley National Laboratory 
+#Ashesh Chattopadhyay - Rice University 
+#Morteza Mardani - NVIDIA Corporation 
+#Thorsten Kurth - NVIDIA Corporation 
+#David Hall - NVIDIA Corporation 
+#Zongyi Li - California Institute of Technology, NVIDIA Corporation 
+#Kamyar Azizzadenesheli - Purdue University 
+#Pedram Hassanzadeh - Rice University 
+#Karthik Kashinath - NVIDIA Corporation 
+#Animashree Anandkumar - California Institute of Technology, NVIDIA Corporation
+
+import os
+import time
+import numpy as np
+import argparse
+import h5py
+import torch
+import cProfile
+import re
+import torchvision
+from torchvision.utils import save_image
+import torch.nn as nn
+import torch.cuda.amp as amp
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel
+import logging
+from utils import logging_utils
+logging_utils.config_logger()
+from utils.YParams import YParams
+from utils.data_loader_multifiles import get_data_loader
+from networks.afnonet import AFNONet, PrecipNet
+from utils.img_utils import vis_precip
+import wandb
+from utils.weighted_acc_rmse import weighted_acc, weighted_rmse, weighted_rmse_torch, unlog_tp_torch
+from apex import optimizers
+from utils.darcy_loss import LpLoss
+import matplotlib.pyplot as plt
+from collections import OrderedDict
+import pickle
+DECORRELATION_TIME = 36 # 9 days
+import json
+from ruamel.yaml import YAML
+from ruamel.yaml.comments import CommentedMap as ruamelDict
+
+class Trainer():
+  def count_parameters(self):
+    return sum(p.numel() for p in self.model.parameters() if p.requires_grad)
+
+  def __init__(self, params, world_rank):
+    
+    self.params = params
+    self.world_rank = world_rank
+    self.device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
+
+    if params.log_to_wandb:
+      wandb.init(config=params, name=params.name, group=params.group, project=params.project)
+
+    logging.info('rank %d, begin data loader init'%world_rank)
+    self.train_data_loader, self.train_dataset, self.train_sampler = get_data_loader(params, params.train_data_path, dist.is_initialized(), train=True)
+    self.valid_data_loader, self.valid_dataset = get_data_loader(params, params.valid_data_path, dist.is_initialized(), train=False)
+    self.loss_obj = LpLoss()
+    logging.info('rank %d, data loader initialized'%world_rank)
+
+    params.crop_size_x = self.valid_dataset.crop_size_x
+    params.crop_size_y = self.valid_dataset.crop_size_y
+    params.img_shape_x = self.valid_dataset.img_shape_x
+    params.img_shape_y = self.valid_dataset.img_shape_y
+
+    # precip models
+    self.precip = True if "precip" in params else False
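+    # Precipitation training uses two models: a frozen, pretrained wind AFNONet produces the
+    # atmospheric state at t+dt, which is then fed to the trainable PrecipNet defined below.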
+    
+    if self.precip:
+      if 'model_wind_path' not in params:
+        raise Exception("no backbone model weights specified")
+      # load a wind model 
+      # the wind model has out channels = in channels
+      out_channels = np.array(params['in_channels'])
+      params['N_out_channels'] = len(out_channels)
+
+      if params.nettype_wind == 'afno':
+        self.model_wind = AFNONet(params).to(self.device)
+      else:
+        raise Exception("not implemented")
+
+      if dist.is_initialized():
+        self.model_wind = DistributedDataParallel(self.model_wind,
+                                            device_ids=[params.local_rank],
+                                            output_device=[params.local_rank],find_unused_parameters=True)
+      self.load_model_wind(params.model_wind_path)
+      self.switch_off_grad(self.model_wind) # no backprop through the wind model
+
+
+    # reset out_channels for precip models
+    if self.precip:
+      params['N_out_channels'] = len(params['out_channels'])
+
+    if params.nettype == 'afno':
+      self.model = AFNONet(params).to(self.device) 
+    else:
+      raise Exception("not implemented")
+     
+    # precip model
+    if self.precip:
+      self.model = PrecipNet(params, backbone=self.model).to(self.device)
+
+    if self.params.enable_nhwc:
+      # NHWC: Convert model to channels_last memory format
+      self.model = self.model.to(memory_format=torch.channels_last)
+
+    if params.log_to_wandb:
+      wandb.watch(self.model)
+
+    if params.optimizer_type == 'FusedAdam':
+      self.optimizer = optimizers.FusedAdam(self.model.parameters(), lr = params.lr)
+    else:
+      self.optimizer = torch.optim.Adam(self.model.parameters(), lr = params.lr)
+
+    if params.enable_amp == True:
+      self.gscaler = amp.GradScaler()
+
+    if dist.is_initialized():
+      self.model = DistributedDataParallel(self.model,
+                                           device_ids=[params.local_rank],
+                                           output_device=[params.local_rank],find_unused_parameters=True)
+
+    self.iters = 0
+    self.startEpoch = 0
+    if params.resuming:
+      logging.info("Loading checkpoint %s"%params.checkpoint_path)
+      self.restore_checkpoint(params.checkpoint_path)
+    if params.two_step_training:
+      if params.resuming == False and params.pretrained == True:
+        logging.info("Starting from pretrained one-step afno model at %s"%params.pretrained_ckpt_path)
+        self.restore_checkpoint(params.pretrained_ckpt_path)
+        self.iters = 0
+        self.startEpoch = 0
+        #logging.info("Pretrained checkpoint was trained for %d epochs"%self.startEpoch)
+        #logging.info("Adding %d epochs specified in config file for refining pretrained model"%self.params.max_epochs)
+        #self.params.max_epochs += self.startEpoch
+
+            
+    self.epoch = self.startEpoch
+
+    if params.scheduler == 'ReduceLROnPlateau':
+      self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, factor=0.2, patience=5, mode='min')
+    elif params.scheduler == 'CosineAnnealingLR':
+      self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=params.max_epochs, last_epoch=self.startEpoch-1)
+    else:
+      self.scheduler = None
+
+    '''if params.log_to_screen:
+      logging.info(self.model)'''
+    if params.log_to_screen:
+      logging.info("Number of trainable model parameters: {}".format(self.count_parameters()))
+
+  def switch_off_grad(self, model):
+    for param in model.parameters():
+      param.requires_grad = False
+
+  def train(self):
+    if self.params.log_to_screen:
+      logging.info("Starting Training Loop...")
+
+    best_valid_loss = 1.e6
+    for epoch in range(self.startEpoch, self.params.max_epochs):
+      if dist.is_initialized():
+        self.train_sampler.set_epoch(epoch)
+#        self.valid_sampler.set_epoch(epoch)
+
+      start = time.time()
+      tr_time, data_time, train_logs = self.train_one_epoch()
+      valid_time, valid_logs = self.validate_one_epoch()
+      if epoch==self.params.max_epochs-1 and self.params.prediction_type == 'direct':
+        valid_weighted_rmse = self.validate_final()
+
+
+
+      if self.params.scheduler == 'ReduceLROnPlateau':
+        self.scheduler.step(valid_logs['valid_loss'])
+      elif self.params.scheduler == 'CosineAnnealingLR':
+        self.scheduler.step()
+        if self.epoch >= self.params.max_epochs:
+          logging.info("Terminating training after reaching params.max_epochs while LR scheduler is set to CosineAnnealingLR")
+          exit()
+
+      if self.params.log_to_wandb:
+        for pg in self.optimizer.param_groups:
+          lr = pg['lr']
+        wandb.log({'lr': lr})
+
+      if self.world_rank == 0:
+        if self.params.save_checkpoint:
+          #checkpoint at the end of every epoch
+          self.save_checkpoint(self.params.checkpoint_path)
+          if valid_logs['valid_loss'] <= best_valid_loss:
+            #logging.info('Val loss improved from {} to {}'.format(best_valid_loss, valid_logs['valid_loss']))
+            self.save_checkpoint(self.params.best_checkpoint_path)
+            best_valid_loss = valid_logs['valid_loss']
+
+      if self.params.log_to_screen:
+        logging.info('Time taken for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
+        #logging.info('train data time={}, train step time={}, valid step time={}'.format(data_time, tr_time, valid_time))
+        logging.info('Train loss: {}. Valid loss: {}'.format(train_logs['loss'], valid_logs['valid_loss']))
+#        if epoch==self.params.max_epochs-1 and self.params.prediction_type == 'direct':
+#          logging.info('Final Valid RMSE: Z500- {}. T850- {}, 2m_T- {}'.format(valid_weighted_rmse[0], valid_weighted_rmse[1], valid_weighted_rmse[2]))
+
+
+
+  def train_one_epoch(self):
+    self.epoch += 1
+    tr_time = 0
+    data_time = 0
+    self.model.train()
+    
+    for i, data in enumerate(self.train_data_loader, 0):
+      self.iters += 1
+      # adjust_LR(optimizer, params, iters)
+      data_start = time.time()
+      inp, tar = map(lambda x: x.to(self.device, dtype = torch.float), data)      
+      if self.params.orography and self.params.two_step_training:
+          orog = inp[:,-2:-1] 
+
+
+      if self.params.enable_nhwc:
+        inp = inp.to(memory_format=torch.channels_last)
+        tar = tar.to(memory_format=torch.channels_last)
+
+
+      if 'residual_field' in self.params.target:
+        tar -= inp[:, 0:tar.size()[1]]
+      data_time += time.time() - data_start
+
+      tr_start = time.time()
+
+      self.model.zero_grad()
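+      # Two-step training unrolls the model autoregressively: the first prediction is fed
+      # back in (concatenated with orography when enabled) and the losses of both steps are
+      # summed, as in the fine-tuning configs in config/AFNO.yaml.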
+      if self.params.two_step_training:
+          with amp.autocast(self.params.enable_amp):
+            gen_step_one = self.model(inp).to(self.device, dtype = torch.float)
+            loss_step_one = self.loss_obj(gen_step_one, tar[:,0:self.params.N_out_channels])
+            if self.params.orography:
+                gen_step_two = self.model(torch.cat( (gen_step_one, orog), axis = 1)  ).to(self.device, dtype = torch.float)
+            else:
+                gen_step_two = self.model(gen_step_one).to(self.device, dtype = torch.float)
+            loss_step_two = self.loss_obj(gen_step_two, tar[:,self.params.N_out_channels:2*self.params.N_out_channels])
+            loss = loss_step_one + loss_step_two
+      else:
+          with amp.autocast(self.params.enable_amp):
+            if self.precip: # use a wind model to predict 17(+n) channels at t+dt
+              with torch.no_grad():
+                inp = self.model_wind(inp).to(self.device, dtype = torch.float)
+              gen = self.model(inp.detach()).to(self.device, dtype = torch.float)
+            else:
+              gen = self.model(inp).to(self.device, dtype = torch.float)
+            loss = self.loss_obj(gen, tar)
+
+      if self.params.enable_amp:
+        self.gscaler.scale(loss).backward()
+        self.gscaler.step(self.optimizer)
+      else:
+        loss.backward()
+        self.optimizer.step()
+
+      if self.params.enable_amp:
+        self.gscaler.update()
+
+      tr_time += time.time() - tr_start
+    
+    try:
+        logs = {'loss': loss, 'loss_step_one': loss_step_one, 'loss_step_two': loss_step_two}
+    except:
+        logs = {'loss': loss}
+
+    if dist.is_initialized():
+      for key in sorted(logs.keys()):
+        dist.all_reduce(logs[key].detach())
+        logs[key] = float(logs[key]/dist.get_world_size())
+
+    if self.params.log_to_wandb:
+      wandb.log(logs, step=self.epoch)
+
+    return tr_time, data_time, logs
+
+  def validate_one_epoch(self):
+    self.model.eval()
+    n_valid_batches = 20 #do validation on first 20 images, just for LR scheduler
+    if self.params.normalization == 'minmax':
+        raise Exception("minmax normalization not supported")
+    elif self.params.normalization == 'zscore':
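+        # per-channel global standard deviations, used below to convert the normalized
+        # weighted RMSE back to physical units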
+        mult = torch.as_tensor(np.load(self.params.global_stds_path)[0, self.params.out_channels, 0, 0]).to(self.device)
+
+    valid_buff = torch.zeros((3), dtype=torch.float32, device=self.device)
+    valid_loss = valid_buff[0].view(-1)
+    valid_l1 = valid_buff[1].view(-1)
+    valid_steps = valid_buff[2].view(-1)
+    valid_weighted_rmse = torch.zeros((self.params.N_out_channels), dtype=torch.float32, device=self.device)
+    valid_weighted_acc = torch.zeros((self.params.N_out_channels), dtype=torch.float32, device=self.device)
+
+    valid_start = time.time()
+
+    sample_idx = np.random.randint(len(self.valid_data_loader))
+    with torch.no_grad():
+      for i, data in enumerate(self.valid_data_loader, 0):
+        if (not self.precip) and i>=n_valid_batches:
+          break    
+        inp, tar  = map(lambda x: x.to(self.device, dtype = torch.float), data)
+        if self.params.orography and self.params.two_step_training:
+            orog = inp[:,-2:-1]
+
+        if self.params.two_step_training:
+            gen_step_one = self.model(inp).to(self.device, dtype = torch.float)
+            loss_step_one = self.loss_obj(gen_step_one, tar[:,0:self.params.N_out_channels])
+
+            if self.params.orography:
+                gen_step_two = self.model(torch.cat( (gen_step_one, orog), axis = 1)  ).to(self.device, dtype = torch.float)
+            else:
+                gen_step_two = self.model(gen_step_one).to(self.device, dtype = torch.float)
+
+            loss_step_two = self.loss_obj(gen_step_two, tar[:,self.params.N_out_channels:2*self.params.N_out_channels])
+            valid_loss += loss_step_one + loss_step_two
+            valid_l1 += nn.functional.l1_loss(gen_step_one, tar[:,0:self.params.N_out_channels])
+        else:
+            if self.precip:
+                with torch.no_grad():
+                    inp = self.model_wind(inp).to(self.device, dtype = torch.float)
+                gen = self.model(inp.detach())
+            else:
+                gen = self.model(inp).to(self.device, dtype = torch.float)
+            valid_loss += self.loss_obj(gen, tar) 
+            valid_l1 += nn.functional.l1_loss(gen, tar)
+
+        valid_steps += 1.
+        # save fields for vis before log norm 
+        if (i == sample_idx) and (self.precip and self.params.log_to_wandb):
+          fields = [gen[0,0].detach().cpu().numpy(), tar[0,0].detach().cpu().numpy()]
+
+        if self.precip:
+          gen = unlog_tp_torch(gen, self.params.precip_eps)
+          tar = unlog_tp_torch(tar, self.params.precip_eps)
+
+        #direct prediction weighted rmse
+        if self.params.two_step_training:
+            if 'residual_field' in self.params.target:
+                valid_weighted_rmse += weighted_rmse_torch((gen_step_one + inp), (tar[:,0:self.params.N_out_channels] + inp))
+            else:
+                valid_weighted_rmse += weighted_rmse_torch(gen_step_one, tar[:,0:self.params.N_out_channels])
+        else:
+            if 'residual_field' in self.params.target:
+                valid_weighted_rmse += weighted_rmse_torch((gen + inp), (tar + inp))
+            else:
+                valid_weighted_rmse += weighted_rmse_torch(gen, tar)
+
+
+        if not self.precip:
+            try:
+                os.mkdir(params['experiment_dir'] + "/" + str(i))
+            except:
+                pass
+            #save first channel of image
+            if self.params.two_step_training:
+                save_image(torch.cat((gen_step_one[0,0], torch.zeros((self.valid_dataset.img_shape_x, 4)).to(self.device, dtype = torch.float), tar[0,0]), axis = 1), params['experiment_dir'] + "/" + str(i) + "/" + str(self.epoch) + ".png")
+            else:
+                save_image(torch.cat((gen[0,0], torch.zeros((self.valid_dataset.img_shape_x, 4)).to(self.device, dtype = torch.float), tar[0,0]), axis = 1), params['experiment_dir'] + "/" + str(i) + "/" + str(self.epoch) + ".png")
+
+           
+    if dist.is_initialized():
+      dist.all_reduce(valid_buff)
+      dist.all_reduce(valid_weighted_rmse)
+
+    # divide by number of steps
+    valid_buff[0:2] = valid_buff[0:2] / valid_buff[2]
+    valid_weighted_rmse = valid_weighted_rmse / valid_buff[2]
+    if not self.precip:
+      valid_weighted_rmse *= mult
+
+    # download buffers
+    valid_buff_cpu = valid_buff.detach().cpu().numpy()
+    valid_weighted_rmse_cpu = valid_weighted_rmse.detach().cpu().numpy()
+
+    valid_time = time.time() - valid_start
+    valid_weighted_rmse = mult*torch.mean(valid_weighted_rmse, axis = 0)
+    if self.precip:
+      logs = {'valid_l1': valid_buff_cpu[1], 'valid_loss': valid_buff_cpu[0], 'valid_rmse_tp': valid_weighted_rmse_cpu[0]}
+    else:
+      try:
+        logs = {'valid_l1': valid_buff_cpu[1], 'valid_loss': valid_buff_cpu[0], 'valid_rmse_u10': valid_weighted_rmse_cpu[0], 'valid_rmse_v10': valid_weighted_rmse_cpu[1]}
+      except:
+        logs = {'valid_l1': valid_buff_cpu[1], 'valid_loss': valid_buff_cpu[0], 'valid_rmse_u10': valid_weighted_rmse_cpu[0]}#, 'valid_rmse_v10': valid_weighted_rmse[1]}
+    
+    if self.params.log_to_wandb:
+      if self.precip:
+        fig = vis_precip(fields)
+        logs['vis'] = wandb.Image(fig)
+        plt.close(fig)
+      wandb.log(logs, step=self.epoch)
+
+    return valid_time, logs
+
+  def validate_final(self):
+    self.model.eval()
+    n_valid_batches = int(self.valid_dataset.n_patches_total/self.valid_dataset.n_patches) #validate on whole dataset
+    valid_weighted_rmse = torch.zeros((n_valid_batches, self.params.N_out_channels), device=self.device)
+    valid_loss = torch.zeros(n_valid_batches, device=self.device)
+    valid_l1 = torch.zeros(n_valid_batches, device=self.device)
+    if self.params.normalization == 'minmax':
+        raise Exception("minmax normalization not supported")
+    elif self.params.normalization == 'zscore':
+        mult = torch.as_tensor(np.load(self.params.global_stds_path)[0, self.params.out_channels, 0, 0]).to(self.device)
+
+    with torch.no_grad():
+        for i, data in enumerate(self.valid_data_loader):
+          if i>100:
+            break
+          inp, tar = map(lambda x: x.to(self.device, dtype = torch.float), data)
+          if self.params.orography and self.params.two_step_training:
+              orog = inp[:,-2:-1]
+          if 'residual_field' in self.params.target:
+            tar -= inp[:, 0:tar.size()[1]]
+
+          if self.params.two_step_training:
+              gen_step_one = self.model(inp).to(self.device, dtype = torch.float)
+              loss_step_one = self.loss_obj(gen_step_one, tar[:,0:self.params.N_out_channels])
+
+              if self.params.orography:
+                  gen_step_two = self.model(torch.cat( (gen_step_one, orog), axis = 1)  ).to(self.device, dtype = torch.float)
+              else:
+                  gen_step_two = self.model(gen_step_one).to(self.device, dtype = torch.float)
+
+              loss_step_two = self.loss_obj(gen_step_two, tar[:,self.params.N_out_channels:2*self.params.N_out_channels])
+              valid_loss[i] = loss_step_one + loss_step_two
+              valid_l1[i] = nn.functional.l1_loss(gen_step_one, tar[:,0:self.params.N_out_channels])
+          else:
+              gen = self.model(inp)
+              valid_loss[i] += self.loss_obj(gen, tar)
+              valid_l1[i] += nn.functional.l1_loss(gen, tar)
+
+          if self.params.two_step_training:
+              for c in range(self.params.N_out_channels):
+                if 'residual_field' in self.params.target:
+                  valid_weighted_rmse[i, c] = weighted_rmse_torch((gen_step_one[0,c] + inp[0,c]), (tar[0,c]+inp[0,c]), self.device)
+                else:
+                  valid_weighted_rmse[i, c] = weighted_rmse_torch(gen_step_one[0,c], tar[0,c], self.device)
+          else:
+              for c in range(self.params.N_out_channels):
+                if 'residual_field' in self.params.target:
+                  valid_weighted_rmse[i, c] = weighted_rmse_torch((gen[0,c] + inp[0,c]), (tar[0,c]+inp[0,c]), self.device)
+                else:
+                  valid_weighted_rmse[i, c] = weighted_rmse_torch(gen[0,c], tar[0,c], self.device)
+            
+        #un-normalize
+        valid_weighted_rmse = mult*torch.mean(valid_weighted_rmse[0:100], axis = 0).to(self.device)
+
+    return valid_weighted_rmse
+
+
+  def load_model_wind(self, model_path):
+    if self.params.log_to_screen:
+      logging.info('Loading the wind model weights from {}'.format(model_path))
+    checkpoint = torch.load(model_path, map_location='cuda:{}'.format(self.params.local_rank))
+    if dist.is_initialized():
+      self.model_wind.load_state_dict(checkpoint['model_state'])
+    else:
+      new_model_state = OrderedDict()
+      model_key = 'model_state' if 'model_state' in checkpoint else 'state_dict'
+      for key in checkpoint[model_key].keys():
+          if 'module.' in key: # model was stored using ddp which prepends module
+              name = str(key[7:])
+              new_model_state[name] = checkpoint[model_key][key]
+          else:
+              new_model_state[key] = checkpoint[model_key][key]
+      self.model_wind.load_state_dict(new_model_state)
+      self.model_wind.eval()
+
+  def save_checkpoint(self, checkpoint_path, model=None):
+    """ We intentionally require a checkpoint_dir to be passed
+        in order to allow Ray Tune to use this function """
+
+    if not model:
+      model = self.model
+
+    torch.save({'iters': self.iters, 'epoch': self.epoch, 'model_state': model.state_dict(),
+                  'optimizer_state_dict': self.optimizer.state_dict()}, checkpoint_path)
+
+  def restore_checkpoint(self, checkpoint_path):
+    """ We intentionally require a checkpoint_dir to be passed
+        in order to allow Ray Tune to use this function """
+    checkpoint = torch.load(checkpoint_path, map_location='cuda:{}'.format(self.params.local_rank))
+    try:
+        self.model.load_state_dict(checkpoint['model_state'])
+    except:
+        new_state_dict = OrderedDict()
+        for key, val in checkpoint['model_state'].items():
+            name = key[7:]
+            new_state_dict[name] = val 
+        self.model.load_state_dict(new_state_dict)
+    self.iters = checkpoint['iters']
+    self.startEpoch = checkpoint['epoch']
+    if self.params.resuming:  #restore checkpoint is used for finetuning as well as resuming. If finetuning (i.e., not resuming), restore checkpoint does not load optimizer state, instead uses config specified lr.
+        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+
+if __name__ == '__main__':
+  parser = argparse.ArgumentParser()
+  parser.add_argument("--run_num", default='00', type=str)
+  parser.add_argument("--yaml_config", default='./config/AFNO.bak.yaml', type=str)
+  parser.add_argument("--config", default='default', type=str)
+  parser.add_argument("--enable_amp", action='store_true')
+  parser.add_argument("--epsilon_factor", default = 0, type = float)
+
+  args = parser.parse_args()
+
+  params = YParams(os.path.abspath(args.yaml_config), args.config)
+  params['epsilon_factor'] = args.epsilon_factor
+
+  params['world_size'] = 1
+  if 'WORLD_SIZE' in os.environ:
+    params['world_size'] = int(os.environ['WORLD_SIZE'])
+
+  world_rank = 0
+  local_rank = 0
+  if params['world_size'] > 1:
+    dist.init_process_group(backend='nccl',
+                            init_method='env://')
+    local_rank = int(os.environ["LOCAL_RANK"])
+    args.gpu = local_rank
+    world_rank = dist.get_rank()
+    params['global_batch_size'] = params.batch_size
+    params['batch_size'] = int(params.batch_size//params['world_size'])
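+    # the configured batch_size is treated as the global batch size and split evenly across ranks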
+
+  torch.cuda.set_device(local_rank)
+  torch.backends.cudnn.benchmark = True
+
+  # Set up directory
+  expDir = os.path.join(params.exp_dir, args.config, str(args.run_num))
+  if  world_rank==0:
+    if not os.path.isdir(expDir):
+      os.makedirs(expDir)
+      os.makedirs(os.path.join(expDir, 'training_checkpoints/'))
+
+  params['experiment_dir'] = os.path.abspath(expDir)
+  params['checkpoint_path'] = os.path.join(expDir, 'training_checkpoints/ckpt.tar')
+  params['best_checkpoint_path'] = os.path.join(expDir, 'training_checkpoints/best_ckpt.tar')
+
+  # Do not comment this line out please:
+  args.resuming = True if os.path.isfile(params.checkpoint_path) else False
+
+  params['resuming'] = args.resuming
+  params['local_rank'] = local_rank
+  params['enable_amp'] = args.enable_amp
+
+  # this will be the wandb name
+#  params['name'] = args.config + '_' + str(args.run_num)
+#  params['group'] = "era5_wind" + args.config
+  params['name'] = args.config + '_' + str(args.run_num)
+  params['group'] = "era5_precip" + args.config
+  params['project'] = "ERA5_precip"
+  params['entity'] = "flowgan"
+  if world_rank==0:
+    logging_utils.log_to_file(logger_name=None, log_filename=os.path.join(expDir, 'out.log'))
+    logging_utils.log_versions()
+    params.log()
+
+  params['log_to_wandb'] = (world_rank==0) and params['log_to_wandb']
+  params['log_to_screen'] = (world_rank==0) and params['log_to_screen']
+
+  params['in_channels'] = np.array(params['in_channels'])
+  params['out_channels'] = np.array(params['out_channels'])
+  if params.orography:
+    params['N_in_channels'] = len(params['in_channels']) +1
+  else:
+    params['N_in_channels'] = len(params['in_channels'])
+
+  params['N_out_channels'] = len(params['out_channels'])
+
+  if world_rank == 0:
+    hparams = ruamelDict()
+    yaml = YAML()
+    for key, value in params.params.items():
+      hparams[str(key)] = str(value)
+    with open(os.path.join(expDir, 'hyperparams.yaml'), 'w') as hpfile:
+      yaml.dump(hparams,  hpfile )
+
+  trainer = Trainer(params, world_rank)
+  trainer.train()
+  logging.info('DONE ---- rank %d'%world_rank)
-- 
GitLab