WeatherBench data
This ClimaX example uses the WeatherBench benchmark dataset. The dataset has been pre-processed for ClimaX training and is available in our dk92 project under '/g/data/dk92/apps/climax/weatherbench/5.625deg_npz'. The training, test, and validation data are kept in separate folders.
It consists of the following files and folders:
├── lat.npy
├── lon.npy
├── normalize_mean.npz
├── normalize_std.npz
├── test
├── train
└── val
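If you want a quick look at these files before training, the snippet below loads the coordinate arrays and normalisation statistics with NumPy. This is a minimal sketch: the paths are the dk92 locations above, but the variable names stored inside the .npz archives are assumptions you should confirm via the .files attribute.
import numpy as np

root = '/g/data/dk92/apps/climax/weatherbench/5.625deg_npz'

# Coordinate arrays for the 5.625-degree global grid
lat = np.load(f'{root}/lat.npy')
lon = np.load(f'{root}/lon.npy')
print(lat.shape, lon.shape)

# Per-variable normalisation statistics used during training
mean = np.load(f'{root}/normalize_mean.npz')
std = np.load(f'{root}/normalize_std.npz')
print(mean.files)   # lists one entry per stored variable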
Training Script
Training ClimaX requires a large amount of computing resources, so you should submit a PBS job to train the model. ClimaX provides several Python training scripts, available in "/g/data/dk92/apps/climax/0.2.3/src",
and the necessary configuration files are located under "/g/data/dk92/apps/climax/0.2.3/configs".
A shell script for fine-tuning the ClimaX model on the WeatherBench dataset is provided as "/g/data/dk92/apps/climax/0.2.3/examples/py_global_mn.sh" and is shown below:
#!/usr/bin/env bash
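# PBS exports PBS_NGPUS (total GPUs) and PBS_NNODES (nodes); derive the per-node GPU count.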
export GPUS_PER_NODE=$((PBS_NGPUS / PBS_NNODES))
export MAX_EPOCHS=1
export OUTPUT_DIR="${PBS_O_WORKDIR}/climax_train_global_output"
export CONFIG_PATH="${CLIMAX_ROOT}/configs/global_forecast_climax.yaml"
export ROOT_DIR='/g/data/dk92/apps/climax/weatherbench/5.625deg_npz'
export PRETRAIN_PATH='/g/data/dk92/apps/climax/weatherbench/ClimaX-5.625deg.ckpt'
export PROG_BAR='True'
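# Derive a 0-based rank for this node from its line number in $PBS_NODEFILE.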
HOST=`hostname`
RANK=`cat $PBS_NODEFILE | uniq | grep -n $HOST | cut -f1 -d :`
export NODE_RANK=$((RANK-1))
echo "=========================================================="
echo "RANK=${RANK} HOST=${HOST}"
echo "MAX_EPOCH=${MAX_EPOCHS}"
echo "OUTPUT_DIR=${OUTPUT_DIR}"
echo "CONFIG_PATH=${CONFIG_PATH}"
echo "ROOT_DIR=${ROOT_DIR}"
echo "PRETRAIN_PATH=${PRETRAIN_PATH}"
echo "=========================================================="
python \
${CLIMAX_ROOT}/src/climax/global_forecast/train.py \
--config ${CONFIG_PATH} \
--trainer.num_nodes=${PBS_NNODES} \
--trainer.strategy=ddp --trainer.devices=${GPUS_PER_NODE} \
--trainer.max_epochs=${MAX_EPOCHS} \
--trainer.enable_progress_bar=${PROG_BAR} \
--data.root_dir=${ROOT_DIR} \
--data.predict_range=72 \
--data.out_variables=['geopotential_500','temperature_850','2m_temperature'] \
--data.batch_size=16 \
--model.pretrained_path=${PRETRAIN_PATH} \
--model.lr=5e-7 --model.beta_1="0.9" --model.beta_2="0.99" \
--model.weight_decay=1e-5
Copy it to your own working directory and make any necessary changes to the model parameters and flags. You can then run it from a PBS job requesting various GPU resources.
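For example, a minimal copy-and-edit workflow (the flag names come from the script above; the values are only illustrative):
# Copy the example into your working directory and edit the copy
cp /g/data/dk92/apps/climax/0.2.3/examples/py_global_mn.sh .
# e.g. train for more epochs with a smaller batch size:
#   export MAX_EPOCHS=10
#   --data.batch_size=8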
Single Node
You can run "py_global_mn.sh" directly in a job script requesting single GPU node resources with multiple GPU devices. Please add your own project ID in the PBS storage list.
[rxy900@gadi-login-07 examples]$ more pl_job.pbs
#!/bin/bash
#PBS -q gpuvolta
#PBS -l ncpus=48
#PBS -l ngpus=4
#PBS -l jobfs=200GB
#PBS -l storage=gdata/dk92+gdata/YOUR_PROJECT
#PBS -l mem=380GB
#PBS -l walltime=02:00:00
#PBS -l wd
#PBS -N Gl_G_8_s_20
module use /g/data/dk92/apps/Modules/modulefiles/
module load climax/0.2.3
./pl_global_mn.sh
The output shows information about the distributed training setup, such as the local rank, global rank, and distributed backend.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/4
Initializing distributed: GLOBAL_RANK: 2, MEMBER: 3/4
Initializing distributed: GLOBAL_RANK: 3, MEMBER: 4/4
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/4
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 4 processes
----------------------------------------------------------------------------------------------------
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 3 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 2 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
  | Name            | Type      | Params
----------------------------------------------
0 | net             | ClimaX    | 108 M
1 | denormalization | Normalize | 0
----------------------------------------------
108 M     Trainable params
0         Non-trainable params
108 M     Total params
216.177   Total estimated model params size (MB)
...
Multiple Nodes
You can use "mpirun" to run the training script across multiple GPU nodes. An example PBS job script requesting 2 GPU nodes is given in /g/data/dk92/apps/climax/0.2.3/examples/mpi_job.pbs and it is shown below
#!/bin/bash
#PBS -q gpuvolta
#PBS -l ncpus=96
#PBS -l ngpus=8
#PBS -l jobfs=800GB
#PBS -l storage=gdata/dk92+gdata/YOUR_PROJECT
#PBS -l mem=760GB
#PBS -l walltime=02:00:00
#PBS -l wd
#PBS -N Gl_G_8_s_20
module use /g/data/dk92/apps/Modules/modulefiles/
module load climax/0.2.3
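# torch.distributed rendezvous settings: WORLD_SIZE is the total number of GPU
# processes, MASTER_ADDR the first node of the job, MASTER_PORT any free port.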
export WORLD_SIZE=$PBS_NGPUS
export MASTER_ADDR=$(head -n 1 $PBS_NODEFILE | uniq )
export MASTER_PORT=10002
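# Launch one copy of the script on each node (--map-by node); PyTorch Lightning
# then starts one DDP process per GPU within each node.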
mpirun -np ${PBS_NNODES} --bind-to none --map-by node py_global_mn.sh
The output from the two nodes shows some key aspects of the distributed training configuration. Eight processes are initialized with the NCCL backend, a high-throughput GPU communication framework. Each of the two nodes hosts four GPUs, with local ranks 0 to 3.
Using 16bit native Automatic Mixed Precision (AMP)
Using 16bit native Automatic Mixed Precision (AMP)
...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
...
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/8
Initializing distributed: GLOBAL_RANK: 5, MEMBER: 6/8
Initializing distributed: GLOBAL_RANK: 6, MEMBER: 7/8
Initializing distributed: GLOBAL_RANK: 2, MEMBER: 3/8
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/8
Initializing distributed: GLOBAL_RANK: 4, MEMBER: 5/8
Initializing distributed: GLOBAL_RANK: 3, MEMBER: 4/8
Initializing distributed: GLOBAL_RANK: 7, MEMBER: 8/8
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 8 processes
----------------------------------------------------------------------------------------------------
...
LOCAL_RANK: 2 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 3 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 2 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 3 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
The approximate training time for one epoch of ClimaX on the global WeatherBench 5.625° resolution data is listed below:
V100 GPU(s) | Approx. Epoch Time (min)
------------|-------------------------
1           | 117.3
4           | 34.5
8           | 17.5
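As a rough check on scaling, those timings correspond to a parallel efficiency of about 85% on 4 GPUs and 84% on 8 GPUs. A minimal sketch of the arithmetic:
# Speed-up and efficiency implied by the epoch timings in the table above
t1 = 117.3                          # minutes for one epoch on 1 V100
for n, tn in [(4, 34.5), (8, 17.5)]:
    speedup = t1 / tn
    print(f'{n} GPUs: speedup {speedup:.2f}x, efficiency {speedup / n:.0%}')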
PyTorch Lightning wrapper
NCI provides an alternative way of running the ClimaX training script that can be put into a PBS job requesting an arbitrary scale of resources, i.e. from a single node to multiple nodes. Simply replace "python" in "py_global_mn.sh" with "pl_python", as shown in "/g/data/dk92/apps/climax/0.2.3/examples/pl_global_mn.sh". An example PBS job script for running it is shown below and is available as "/g/data/dk92/apps/climax/0.2.3/examples/pl_job.pbs".
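The two scripts differ only in the launcher line; the config, flags, and environment setup are unchanged:
# py_global_mn.sh launches the trainer with:
python \
    ${CLIMAX_ROOT}/src/climax/global_forecast/train.py \
    ...
# pl_global_mn.sh launches the same trainer with the wrapper:
pl_python \
    ${CLIMAX_ROOT}/src/climax/global_forecast/train.py \
    ...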
#!/bin/bash
#PBS -q gpuvolta
#PBS -l ncpus=96
#PBS -l ngpus=8
#PBS -l jobfs=800GB
#PBS -l storage=gdata/dk92+gdata/YOUR_PROJECT
#PBS -l mem=760GB
#PBS -l walltime=02:00:00
#PBS -l wd
#PBS -N Gl_G_8_s_20
module use /g/data/dk92/apps/Modules/modulefiles/
module load climax/0.2.3
./pl_global_mn.sh
It will produce output from each worker node.
Model Parameters
The ClimaX model has about 300 layers and 108 million parameters in total. The layer names and parameter counts are listed below; the repeated net.token_embeds.1–46 and net.blocks.2–7 entries follow the same patterns as their neighbours and are elided with "...":
┏━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓
┃     ┃ Name                        ┃ Type                            ┃ Params ┃
┡━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩
│ 0   │ net                         │ RegionalClimaX                  │ 108 M  │
│ 1   │ net.token_embeds            │ ModuleList                      │ 245 K  │
│ 2   │ net.token_embeds.0          │ PatchEmbed                      │ 5.1 K  │
│ 3   │ net.token_embeds.0.proj     │ Conv2d                          │ 5.1 K  │
│ 4   │ net.token_embeds.0.norm     │ Identity                        │ 0      │
│ ... │ ...                         │ ...                             │ ...    │
│ 143 │ net.token_embeds.47         │ PatchEmbed                      │ 5.1 K  │
│ 144 │ net.token_embeds.47.proj    │ Conv2d                          │ 5.1 K  │
│ 145 │ net.token_embeds.47.norm    │ Identity                        │ 0      │
│ 146 │ net.var_agg                 │ MultiheadAttention              │ 4.2 M  │
│ 147 │ net.var_agg.out_proj        │ NonDynamicallyQuantizableLinear │ 1.0 M  │
│ 148 │ net.lead_time_embed         │ Linear                          │ 2.0 K  │
│ 149 │ net.pos_drop                │ Dropout                         │ 0      │
│ 150 │ net.blocks                  │ ModuleList                      │ 100 M  │
│ 151 │ net.blocks.0                │ Block                           │ 12.6 M │
│ 152 │ net.blocks.0.norm1          │ LayerNorm                       │ 2.0 K  │
│ 153 │ net.blocks.0.attn           │ Attention                       │ 4.2 M  │
│ 154 │ net.blocks.0.attn.qkv       │ Linear                          │ 3.1 M  │
│ 155 │ net.blocks.0.attn.attn_drop │ Dropout                         │ 0      │
│ 156 │ net.blocks.0.attn.proj      │ Linear                          │ 1.0 M  │
│ 157 │ net.blocks.0.attn.proj_drop │ Dropout                         │ 0      │
│ 158 │ net.blocks.0.ls1            │ Identity                        │ 0      │
│ 159 │ net.blocks.0.drop_path1     │ Identity                        │ 0      │
│ 160 │ net.blocks.0.norm2          │ LayerNorm                       │ 2.0 K  │
│ 161 │ net.blocks.0.mlp            │ Mlp                             │ 8.4 M  │
│ 162 │ net.blocks.0.mlp.fc1        │ Linear                          │ 4.2 M  │
│ 163 │ net.blocks.0.mlp.act        │ GELU                            │ 0      │
│ 164 │ net.blocks.0.mlp.drop1      │ Dropout                         │ 0      │
│ 165 │ net.blocks.0.mlp.fc2        │ Linear                          │ 4.2 M  │
│ 166 │ net.blocks.0.mlp.drop2      │ Dropout                         │ 0      │
│ 167 │ net.blocks.0.ls2            │ Identity                        │ 0      │
│ 168 │ net.blocks.0.drop_path2     │ Identity                        │ 0      │
│ 169 │ net.blocks.1                │ Block                           │ 12.6 M │
│ 170 │ net.blocks.1.norm1          │ LayerNorm                       │ 2.0 K  │
│ 171 │ net.blocks.1.attn           │ Attention                       │ 4.2 M  │
│ 172 │ net.blocks.1.attn.qkv       │ Linear                          │ 3.1 M  │
│ 173 │ net.blocks.1.attn.attn_drop │ Dropout                         │ 0      │
│ 174 │ net.blocks.1.attn.proj      │ Linear                          │ 1.0 M  │
│ 175 │ net.blocks.1.attn.proj_drop │ Dropout                         │ 0      │
│ 176 │ net.blocks.1.ls1            │ Identity                        │ 0      │
│ 177 │ net.blocks.1.drop_path1     │ DropPath                        │ 0      │
│ 178 │ net.blocks.1.norm2          │ LayerNorm                       │ 2.0 K  │
│ 179 │ net.blocks.1.mlp            │ Mlp                             │ 8.4 M  │
│ 180 │ net.blocks.1.mlp.fc1        │ Linear                          │ 4.2 M  │
│ 181 │ net.blocks.1.mlp.act        │ GELU                            │ 0      │
│ 182 │ net.blocks.1.mlp.drop1      │ Dropout                         │ 0      │
│ 183 │ net.blocks.1.mlp.fc2        │ Linear                          │ 4.2 M  │
│ 184 │ net.blocks.1.mlp.drop2     │ Dropout                         │ 0      │
│ 185 │ net.blocks.1.ls2            │ Identity                        │ 0      │
│ 186 │ net.blocks.1.drop_path2     │ DropPath                        │ 0      │
│ ... │ ...                         │ ...                             │ ...    │
│ 295 │ net.norm                    │ LayerNorm                       │ 2.0 K  │
│ 296 │ net.head                    │ Sequential                      │ 2.3 M  │
│ 297 │ net.head.0                  │ Linear                          │ 1.0 M  │
│ 298 │ net.head.1                  │ GELU                            │ 0      │
│ 299 │ net.head.2                  │ Linear                          │ 1.0 M  │
│ 300 │ net.head.3                  │ GELU                            │ 0      │
│ 301 │ net.head.4                  │ Linear                          │ 196 K  │
│ 302 │ denormalization             │ Normalize                       │ 0      │
└─────┴─────────────────────────────┴─────────────────────────────────┴────────┘
Trainable params: 108 M
Non-trainable params: 0
Total params: 108 M
Total estimated model params size (MB): 216
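If you want to reproduce these totals for any PyTorch model, the counting is a one-liner over model.parameters(). A minimal, self-contained sketch on a toy module (not ClimaX itself); the 2-bytes-per-parameter size estimate matches the 216 MB figure above because training runs under 16-bit AMP:
import torch.nn as nn

# Any nn.Module works the same way; ClimaX is ultimately just an nn.Module.
model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print(f'Trainable params: {trainable:,}')
print(f'Non-trainable params: {frozen:,}')
print(f'Approx. size (MB) at 16-bit: {2 * (trainable + frozen) / 1e6:.3f}')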