# AFNO.yaml (committed by xuetaowave)
### base config ###
# Base configuration. It is anchored as FULL_FIELD so that other top-level
# configurations in this file can merge these defaults with `<<: *FULL_FIELD`
# and override individual keys.
#
# NOTE(review): values written as 'None' are the literal string "None", not
# YAML null (which is spelled `null`). They are quoted here to make that
# explicit; presumably the loading code maps the string to a null value --
# confirm against the loader before replacing them with `null`.
full_field: &FULL_FIELD
  loss: 'l2'
  # Written with a decimal point so that YAML 1.1 loaders (e.g. PyYAML) also
  # resolve it as a float; a bare `1E-3` is read as a string there.
  lr: 1.0e-3
  scheduler: 'ReduceLROnPlateau'
  retrain: false
  num_data_workers: 4
  dt: 1  # how many timesteps ahead the model will predict
  n_history: 0  # how many previous timesteps to consider
  prediction_type: 'iterative'
  prediction_length: 41  # applicable only if prediction_type == 'iterative'
  n_initial_conditions: 5  # applicable only if prediction_type == 'iterative'
  ics_type: 'default'
  save_raw_forecasts: true
  save_channel: false
  masked_acc: false
  maskpath: 'None'
  perturb: false
  add_grid: false
  N_grid_channels: 0
  gridtype: 'sinusoidal'  # options 'sinusoidal' or 'linear'
  roll: false
  max_epochs: 50
  batch_size: 64

  # afno hyperparams
  num_blocks: 8
  nettype: 'afno'
  patch_size: 8
  width: 56
  modes: 32
  # options default, residual
  target: 'default'
  in_channels: [0, 1]
  # must be same as in_channels if prediction_type == 'iterative'
  out_channels: [0, 1]
  normalization: 'zscore'  # options zscore (minmax not supported)
  train_data_path: '/pscratch/sd/j/jpathak/wind/train'
  valid_data_path: '/pscratch/sd/j/jpathak/wind/test'
  # test set path for inference
  inf_data_path: '/pscratch/sd/j/jpathak/wind/out_of_sample'
  exp_dir: '/pscratch/sd/j/jpathak/ERA5_expts_gtc/wind'
  time_means_path: '/pscratch/sd/j/jpathak/wind/time_means.npy'
  global_means_path: '/pscratch/sd/j/jpathak/wind/global_means.npy'
  global_stds_path: '/pscratch/sd/j/jpathak/wind/global_stds.npy'

  orography: false
  orography_path: 'None'

  log_to_screen: true
  log_to_wandb: true
  save_checkpoint: true

  enable_nhwc: false
  optimizer_type: 'FusedAdam'
  crop_size_x: 'None'
  crop_size_y: 'None'

  two_step_training: false
  plot_animations: false

  add_noise: false
  noise_std: 0

# Configuration that merges the FULL_FIELD defaults and then overrides the keys
# listed below. `depth` and `img_size` are not defined in FULL_FIELD and are
# introduced here. Anchored as LJKJ so other configurations can merge it.
#
# NOTE(review): 'None' is the literal string "None", not YAML null; presumably
# the loading code maps it to a null value -- confirm against the loader.
afno_backbone_ljkj: &LJKJ
  <<: *FULL_FIELD
  log_to_wandb: false
  # Written with a decimal point so that YAML 1.1 loaders (e.g. PyYAML) also
  # resolve it as a float; a bare `5E-4` is read as a string there.
  lr: 5.0e-4
  batch_size: 8
  patch_size: 4
  depth: 4  # default 12
  img_size: [128, 256]
  max_epochs: 1500
  scheduler: 'CosineAnnealingLR'
  in_channels: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
  out_channels: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
  prediction_length: 100
  orography: false
  orography_path: 'None'
  # Relative paths; what they are resolved against is decided by the consumer.
  exp_dir: './results/tec_256'
  train_data_path: './train'
  valid_data_path: './test'
  inf_data_path: './out_of_sample'
  time_means_path: './time_means.npy'
  global_means_path: './global_means.npy'
  global_stds_path: './global_stds.npy'

# Same settings as the LJKJ configuration (merged via its anchor), with only
# batch_size overridden: 64 here instead of the 8 set in LJKJ.
afno_backbone_ustc:
  <<: *LJKJ
  batch_size: 64