Skip to content
Snippets Groups Projects
Commit 5309c6e9 authored by xuetao chen's avatar xuetao chen
Browse files

update

parent 4174b173
No related branches found
No related tags found
No related merge requests found
### base config ###
full_field: &FULL_FIELD
loss: 'l2'
lr: 1E-3
scheduler: 'ReduceLROnPlateau'
retrain: !!bool False
num_data_workers: 4
dt: 1 # how many timesteps ahead the model will predict
n_history: 0 #how many previous timesteps to consider
prediction_type: 'iterative'
prediction_length: 41 #applicable only if prediction_type == 'iterative'
n_initial_conditions: 5 #applicable only if prediction_type == 'iterative'
ics_type: "default"
save_raw_forecasts: !!bool True
save_channel: !!bool False
masked_acc: !!bool False
maskpath: None
perturb: !!bool False
add_grid: !!bool False
N_grid_channels: 0
gridtype: 'sinusoidal' #options 'sinusoidal' or 'linear'
roll: !!bool False
max_epochs: 50
batch_size: 64
#afno hyperparams
num_blocks: 8
nettype: 'afno'
patch_size: 8
width: 56
modes: 32
#options default, residual
target: 'default'
in_channels: [0,1]
out_channels: [0,1] #must be same as in_channels if prediction_type == 'iterative'
normalization: 'zscore' #options zscore (minmax not supported)
train_data_path: '/pscratch/sd/j/jpathak/wind/train'
valid_data_path: '/pscratch/sd/j/jpathak/wind/test'
inf_data_path: '/pscratch/sd/j/jpathak/wind/out_of_sample' # test set path for inference
exp_dir: '/pscratch/sd/j/jpathak/ERA5_expts_gtc/wind'
time_means_path: '/pscratch/sd/j/jpathak/wind/time_means.npy'
global_means_path: '/pscratch/sd/j/jpathak/wind/global_means.npy'
global_stds_path: '/pscratch/sd/j/jpathak/wind/global_stds.npy'
orography: !!bool False
orography_path: None
log_to_screen: !!bool True
log_to_wandb: !!bool True
save_checkpoint: !!bool True
enable_nhwc: !!bool False
optimizer_type: 'FusedAdam'
crop_size_x: None
crop_size_y: None
two_step_training: !!bool False
plot_animations: !!bool False
add_noise: !!bool False
noise_std: 0
afno_backbone_ljkj: &LJKJ
<<: *FULL_FIELD
log_to_wandb: !!bool False
lr: 5E-4
batch_size: 2
patch_size: 2
depth : 6 # default 12
img_size: [192, 288]
max_epochs: 1500
scheduler: 'CosineAnnealingLR'
in_channels_range: [ 0, 162, 1 ]
out_channels_range: [ 0, 162, 1 ]
# in_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
# out_channels: [0, 1 ,2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
prediction_length: 100
orography: !!bool False
orography_path: None
exp_dir: './results/tec_256'
train_data_path: './train'
valid_data_path: './test'
inf_data_path: './out_of_sample'
time_means_path: './time_means.npy'
global_means_path: './global_means.npy'
global_stds_path: './global_stds.npy'
afno_backbone_ustc:
<<: *LJKJ
batch_size: 8
depth: 12
patch_size: 2
\ No newline at end of file
import numpy as np
import h5py
data_load = h5py.File('train/2010.h5', 'r')
data_load.keys()
data = data_load['fields']
global_mean = np.mean(data, axis=(0, 2, 3))[None, :, None, None]
print(global_mean.shape)
np.save('global_means.npy', global_mean)
global_stds = np.std(data, axis=(0, 2, 3))[None, :, None, None]
np.save('global_stds.npy', global_stds)
time_means = np.mean(data, axis=0)[None]
print(time_means.shape)
np.save('time_means.npy', time_means)
import numpy as np
from datetime import datetime, timedelta
from time import time
import os
import platform
os.getcwd()
start_date = datetime(2010, 1, 11)
end_date = datetime(2010, 12, 31)
if 'WSL' in platform.release():
path_list = {
'Density':'/mnt/h/work/Data/WACCM/',
}
else:
path_list = {
'Density': '/home/ess/cxt/work/data/WACCM/',
}
kind = 'Density'
value = []
date_list = []
for i in range((end_date-start_date).days):
t0 = time()
c_date = start_date+timedelta(i)
filename = f'{path_list[kind]}/{c_date.year}/{kind}_height/FWSD_2010.cam.h1.' + c_date.strftime(
'%Y-%m-%d') + f'-00000.{kind}.2-5km.npz'
seconds = np.arange(0, 24*60*60, 60*30)
date_i = np.array([c_date+timedelta(0, int(seconds[i])) for i in range(seconds.size)])
if os.path.exists(filename) is True:
with np.load(filename) as data_load:
value.append(np.array(data_load['value'][1::2].astype('float32')))
date_list.append(date_i[1::2])
t1 = time()
print('time usage:{} finish:{}'.format(t1-t0, filename))
else:
print('not found: {}'.format(filename))
value_all = np.concatenate(value, axis=0)
np.savez(kind, fields=value_all, date=np.array(date_list))
import numpy as np
import h5py
data_load_U = np.load('Density.npz')
U = data_load_U['fields']
n_levels = 27
data = U.reshape((-1, 6*n_levels)+U.shape[2:])
train_mask = data.shape[0]//6*5
import os
print(os.getcwd())
os.makedirs('train', exist_ok=True)
os.makedirs('test', exist_ok=True)
os.makedirs('out_of_sample', exist_ok=True)
import h5py
with h5py.File('train/2010.h5', 'w') as f:
f.create_dataset("fields", data=data[:train_mask])
with h5py.File('test/2010.h5', 'w') as f:
f.create_dataset("fields", data=data[train_mask:])
with h5py.File('out_of_sample/2010.h5', 'w') as f:
f.create_dataset("fields", data=data[train_mask:])
#!/bin/bash -l
#SBATCH --time=24:00:00
#SBATCH -J afno
#SBATCH -o afno_backbone_finetune.out
#SBATCH -N 1 -n 1 -c 8 --gres=gpu:a100:1 -p GPU-8A100 --qos=gpu_8a100
config_file=./AFNO.yaml
config='afno_backbone_ustc'
run_num='d12p3'
export HDF5_USE_FILE_LOCKING=FALSE
export NCCL_NET_GDR_LEVEL=PHB
export MASTER_ADDR=$(hostname)
set -x
srun -u --mpi=pmi2 \
bash -c "
source /home/ess/cxt/miniconda3/etc/profile.d/conda.sh
conda activate pytorch
python ../train.py --enable_amp --yaml_config=$config_file --config=$config --run_num=$run_num
"
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment