Commit 683c5b8d authored by YU Xiyue
Browse files

5

parent ec35af70
Pipeline #721 canceled with stages
......@@ -16,7 +16,7 @@ class Mixer_Layer(nn.Module):
# 这里需要写Mixer_Layer(layernorm,mlp1,mlp2,skip_connection)
tokens_mlp_dim = 256
channels_mlp_dim = 32
S = 28 * 28 / (patch_size * patch_size)
S = int((28 * 28 / (patch_size * patch_size)))
self.ln_token = nn.LayerNorm(hidden_dim)
self.mlp1 = nn.Sequential(
nn.Linear(S, tokens_mlp_dim),
......@@ -27,16 +27,16 @@ class Mixer_Layer(nn.Module):
self.mlp2 = nn.Sequential(
nn.Linear(hidden_dim, channels_mlp_dim),
nn.GELU(),
nn.Linear(channels_mlp_dim, S)
nn.Linear(channels_mlp_dim, hidden_dim)
)
########################################################################
def forward(self, x):
########################################################################
u = self.ln_token(x).transpose(1, 2)
x = x + self.mlp1(u).transpose(1, 2)
y = self.ln_channel(x)
return y + self.mlp2(self.ln_channel(y))
u = x + self.mlp1(u).transpose(1, 2)
y = self.ln_channel(u)
return u + self.mlp2(y)
########################################################################
......@@ -56,7 +56,7 @@ class MLPMixer(nn.Module):
def forward(self, data):
########################################################################
# 注意维度的变化
y = self.patch_ppfc(data)
y = self.ppfc(data)
y = torch.flatten(y, start_dim=2).transpose(1, 2)
y = self.mlp(y)
y = self.ln(y)
......@@ -73,7 +73,11 @@ def train(model, train_loader, optimizer, n_epochs, criterion):
data, target = data.to(device), target.to(device)
########################################################################
# 计算loss并进行优化
loss = criterion()
output = model(data)
loss = criterion(output, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
########################################################################
if batch_idx % 100 == 0:
print('Train Epoch: {}/{} [{}/{}]\tLoss: {:.6f}'.format(
......@@ -89,7 +93,11 @@ def test(model, test_loader, criterion):
data, target = data.to(device), target.to(device)
########################################################################
# 需要计算测试集的loss和accuracy
output = model(data)
pred = output.argmax(dim=-1)
corr = (pred == target).sum()
accuracy = corr.item() / len(data)
test_loss = criterion(output, target)
########################################################################
print("Test set: Average loss: {:.4f}\t Acc {:.2f}".format(test_loss.item(), accuracy))
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment