Commit c38ce25b authored by Jiakai Song

commit

parent c65df947
.idea/workspace.xml
@@ -3,7 +3,8 @@
<component name="ChangeListManager">
<list default="true" id="0cf971cc-497f-4c74-9aa1-f78fca398209" name="Default Changelist" comment="">
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/envs/soccer_env.py" beforeDir="false" afterPath="$PROJECT_DIR$/envs/soccer_env.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/algorithms/hppo.py" beforeDir="false" afterPath="$PROJECT_DIR$/algorithms/hppo.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/run_hppo.py" beforeDir="false" afterPath="$PROJECT_DIR$/run_hppo.py" afterDir="false" />
</list>
<option name="SHOW_DIALOG" value="false" />
<option name="HIGHLIGHT_CONFLICTS" value="true" />
@@ -24,6 +25,34 @@
<property name="last_opened_file_path" value="$PROJECT_DIR$" />
<property name="settings.editor.selected.configurable" value="com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable" />
</component>
<component name="RunManager">
<configuration name="run_hppo" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
<module name="pdqn_hppo" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<envs>
<env name="PYTHONUNBUFFERED" value="1" />
</envs>
<option name="SDK_HOME" value="" />
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
<option name="IS_MODULE_SDK" value="true" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/run_hppo.py" />
<option name="PARAMETERS" value="" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
<option name="REDIRECT_INPUT" value="false" />
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<recent_temporary>
<list>
<item itemvalue="Python.run_hppo" />
</list>
</recent_temporary>
</component>
<component name="SvnConfiguration">
<configuration />
</component>
......
algorithms/hppo.py
@@ -102,7 +102,7 @@ class HPPO(Agent):
self.lr_a = args.lr_a
self.lr_c = args.lr_c
self.gamma, self.gae_lam = args.gamma, args.gae_lam
self.mini_batch = args.mini_batch
self.n_mini_batch = args.n_mini_batch
self.epsilon = args.epsilon
self.epochs = args.epochs
@@ -131,25 +131,25 @@ class HPPO(Agent):
params = params.cpu().squeeze(0).numpy()
return self.denormalize(action, params), action, params, log_prob1, log_prob2
def update_step(self, s_batch, a_batch, p_batch, avail_batch, old_log_p1, old_log_p2, returns_batch, adv_batch):
v = self.critic(s_batch)
categorical, dist = self.actor(s_batch, avail_batch)
log_p1 = categorical.log_prob(a_batch)
log_p2 = dist.log_prob(p_batch).sum(-1)
def update_step(self, mini_batch):
v = self.critic(mini_batch['s'])
categorical, dist = self.actor(mini_batch['s'], mini_batch['avail'])
log_p1 = categorical.log_prob(mini_batch['a'])
log_p2 = dist.log_prob(mini_batch['p']).sum(-1)
entropy1 = categorical.entropy().mean()
entropy2 = dist.entropy().mean()
# entropy2 = dist.entropy().sum(-1).mean()
ratio1 = torch.exp(log_p1 - old_log_p1)
ratio2 = torch.exp(log_p2 - old_log_p2)
ratio1 = torch.exp(log_p1 - mini_batch['log_prob1'])
ratio2 = torch.exp(log_p2 - mini_batch['log_prob2'])
discrete_loss = -torch.mean(torch.min(
torch.clamp(ratio1, 1 - self.epsilon, 1 + self.epsilon) * adv_batch,
ratio1 * adv_batch
torch.clamp(ratio1, 1 - self.epsilon, 1 + self.epsilon) * mini_batch['adv'],
ratio1 * mini_batch['adv']
))
continuous_loss = -torch.mean(torch.min(
torch.clamp(ratio2, 1 - self.epsilon, 1 + self.epsilon) * adv_batch,
ratio2 * adv_batch
torch.clamp(ratio2, 1 - self.epsilon, 1 + self.epsilon) * mini_batch['adv'],
ratio2 * mini_batch['adv']
))
action_loss = discrete_loss + continuous_loss
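Both heads reuse the same PPO clipping rule here, differing only in which ratio and advantage feed it. A minimal standalone sketch of that surrogate, assuming a generic helper name and an illustrative epsilon of 0.2 (the repository takes epsilon from args.epsilon):

import torch

def clipped_surrogate(ratio, adv, eps):
    # PPO clipped objective: pessimistic minimum of the unclipped and clipped
    # ratio-weighted advantage, negated so it can be minimised as a loss.
    return -torch.mean(torch.min(ratio * adv,
                                 torch.clamp(ratio, 1.0 - eps, 1.0 + eps) * adv))

# Hybrid action loss: one surrogate for the discrete head, one for the
# continuous-parameter head, summed as in update_step above.
ratio1 = torch.exp(torch.randn(32))  # new/old prob ratio, discrete actions
ratio2 = torch.exp(torch.randn(32))  # new/old prob ratio, continuous params
adv = torch.randn(32)                # advantage estimates
action_loss = clipped_surrogate(ratio1, adv, 0.2) + clipped_surrogate(ratio2, adv, 0.2)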
@@ -161,7 +161,7 @@ class HPPO(Agent):
torch.nn.utils.clip_grad_norm_(self.actor.parameters(), self.grad_clip)
self.optim1.step()
value_loss = F.mse_loss(v, returns_batch)
value_loss = F.mse_loss(v, mini_batch['returns'])
self.optim2.zero_grad()
value_loss.backward()
if self.grad_clip is not None:
@@ -170,57 +170,54 @@ class HPPO(Agent):
return entropy1.item(), entropy2.item()
def update_network(self, buffer_s, buffer_a, buffer_p, buffer_avail, log_p1, log_p2, buffer_r, buffer_mask, n):
s = torch.tensor(buffer_s, dtype=torch.float32).to(self.device)
p = torch.tensor(buffer_p, dtype=torch.float32).to(self.device)
r = torch.tensor(buffer_r, dtype=torch.float32).to(self.device)
avail = torch.tensor(buffer_avail, dtype=torch.float32).to(self.device)
log_prob1 = torch.tensor(log_p1, dtype=torch.float32).to(self.device)
log_prob2 = torch.tensor(log_p2, dtype=torch.float32).to(self.device)
v_s = self.critic(s).detach()
action_idx = torch.tensor(buffer_a, dtype=torch.int64).to(self.device)
mask = torch.tensor(buffer_mask, dtype=torch.float32).to(self.device)
def update_network(self, batch):
for key in batch.keys():
if key == 'a':
batch[key] = torch.tensor(batch[key], dtype=torch.int64).to(self.device)
else:
batch[key] = torch.tensor(batch[key], dtype=torch.float32).to(self.device)
n = batch['s'].shape[0]
adv = torch.zeros([n], dtype=torch.float32).to(self.device)
deltas = torch.zeros([n], dtype=torch.float32).to(self.device)
returns = torch.zeros([n], dtype=torch.float32).to(self.device)
v_s = self.critic(batch['s']).detach()
pre_return = 0
pre_adv = 0
pre_v = 0
for i in reversed(range(n)):
returns[i] = r[i] + self.gamma * pre_return * mask[i]
deltas[i] = r[i] + self.gamma * pre_v * mask[i] - v_s[i]
adv[i] = deltas[i] + self.gamma * self.gae_lam * pre_adv * mask[i]
returns[i] = batch['r'][i] + self.gamma * pre_return * batch['mask'][i]
deltas[i] = batch['r'][i] + self.gamma * pre_v * batch['mask'][i] - v_s[i]
adv[i] = deltas[i] + self.gamma * self.gae_lam * pre_adv * batch['mask'][i]
pre_v = v_s[i]
pre_adv = adv[i]
pre_return = returns[i]
if self.adv_norm:
adv = (adv - adv.mean()) / (adv.std() + 1e-8)
adv.clamp_(-10.0, 10.0)
# store advantages only after the optional normalization so the mini-batches use the normalized values
batch['adv'] = adv
batch['returns'] = returns
shuffle = np.random.permutation(n)
mini_batch_size = n // self.mini_batch
mini_batch_size = n // self.n_mini_batch
entropy1_record = []
entropy2_record = []
for _ in range(self.epochs):
for i in range(self.mini_batch):
if i == self.mini_batch - 1:
minibatch = shuffle[i * mini_batch_size:n]
for i in range(self.n_mini_batch):
if i == self.n_mini_batch - 1:
indices = shuffle[i * mini_batch_size:n]
else:
minibatch = shuffle[i * mini_batch_size:(i + 1) * mini_batch_size]
s_batch = s[minibatch]
returns_batch = returns[minibatch]
adv_batch = adv[minibatch]
a_batch = action_idx[minibatch]
p_batch = p[minibatch]
log_p1_batch = log_prob1[minibatch]
log_p2_batch = log_prob2[minibatch]
avail_batch = avail[minibatch]
e1, e2 = self.update_step(s_batch, a_batch, p_batch, avail_batch, log_p1_batch, log_p2_batch, returns_batch, adv_batch)
entropy1_record.append(e1)
entropy2_record.append(e2)
indices = shuffle[i * mini_batch_size:(i + 1) * mini_batch_size]
mini_batch = {}
for key in batch.keys():
if key != 'r' and key != 'mask':
mini_batch[key] = batch[key][indices]
entropy1, entropy2 = self.update_step(mini_batch)
entropy1_record.append(entropy1)
entropy2_record.append(entropy2)
return np.mean(entropy1_record), np.mean(entropy2_record)
def save_model(self, save_dir=None):
......
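The refactor above leaves the GAE recursion itself unchanged; only the indexing now goes through the batch dict. For reference, a standalone sketch of that backward pass under the same conventions (function name and array layout are illustrative; the gamma and gae_lam defaults mirror the argparse values in run_hppo.py below):

import numpy as np

def gae_returns(r, mask, v, gamma=0.99, gae_lam=0.97):
    # Backward pass over a rollout: discounted returns plus GAE advantages.
    # r, mask, v are 1-D arrays of equal length; mask[i] is 0 where an episode ends.
    n = len(r)
    returns = np.zeros(n, dtype=np.float32)
    adv = np.zeros(n, dtype=np.float32)
    pre_return = pre_adv = pre_v = 0.0
    for i in reversed(range(n)):
        returns[i] = r[i] + gamma * pre_return * mask[i]
        delta = r[i] + gamma * pre_v * mask[i] - v[i]
        adv[i] = delta + gamma * gae_lam * pre_adv * mask[i]
        pre_return, pre_adv, pre_v = returns[i], adv[i], v[i]
    return returns, adv

rets, advs = gae_returns(r=np.ones(5, dtype=np.float32),
                         mask=np.array([1, 1, 0, 1, 1], dtype=np.float32),
                         v=np.zeros(5, dtype=np.float32))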
run_hppo.py
@@ -36,7 +36,7 @@ if __name__ == '__main__':
parser.add_argument('--frame_stack', type=int, default=1)
parser.add_argument('--lr_a', type=float, default=0.0001)
parser.add_argument('--lr_c', type=float, default=0.0002)
parser.add_argument('--mini_batch', type=int, default=4)
parser.add_argument('--n_mini_batch', type=int, default=4)
parser.add_argument('--epochs', type=int, default=4)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--gae_lam', type=float, default=0.97)
@@ -100,17 +100,18 @@ if __name__ == '__main__':
batch = memory.sample()
n = len(memory)
state_batch = batch.state
action_batch = batch.action
params_batch = batch.params
log_prob1_batch = batch.log_prob1
log_prob2_batch = batch.log_prob2
reward_batch = batch.reward
mask_batch = batch.mask
avail_batch = batch.avail_actions
batch = {
's': batch.state,
'a': batch.action,
'p': batch.params,
'log_prob1': batch.log_prob1,
'log_prob2': batch.log_prob2,
'r': batch.reward,
'mask': batch.mask,
'avail': batch.avail_actions,
}
e1, e2 = agent.update_network(state_batch, action_batch, params_batch, avail_batch, log_prob1_batch,
log_prob2_batch, reward_batch, mask_batch, n)
entropy1, entropy2 = agent.update_network(batch)
mean_ep_reward = reward_record[-1]['mean_ep_reward']
mean_ep_length = reward_record[-1]['mean_ep_length']
@@ -119,8 +120,8 @@ if __name__ == '__main__':
writer.add_scalar('reward/iteration', mean_ep_reward, iteration)
writer.add_scalar('episode_length/iteration', mean_ep_length, iteration)
writer.add_scalar('entropy1/iteration', e1, iteration)
writer.add_scalar('entropy2/iteration', e2, iteration)
writer.add_scalar('entropy1/iteration', entropy1, iteration)
writer.add_scalar('entropy2/iteration', entropy2, iteration)
if (iteration + 1) % args.save_interval == 0:
agent.save_model(args.save_dir)
......
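With this change the training loop hands a single dict to agent.update_network, which tensorises each entry itself (discrete action indices as int64, everything else as float32). A rough sketch of that contract, with placeholder shapes and a hypothetical to_tensors helper standing in for the conversion loop inside update_network:

import numpy as np
import torch

def to_tensors(batch, device='cpu'):
    # Mirrors the conversion at the top of HPPO.update_network: action ids
    # become int64 tensors, every other entry becomes float32.
    return {k: torch.tensor(v, dtype=torch.int64 if k == 'a' else torch.float32,
                            device=device)
            for k, v in batch.items()}

n, obs_dim, param_dim, n_actions = 4, 8, 3, 4  # placeholder sizes only
batch = {
    's': np.zeros((n, obs_dim), dtype=np.float32),      # states
    'a': np.zeros(n, dtype=np.int64),                    # discrete action ids
    'p': np.zeros((n, param_dim), dtype=np.float32),     # continuous parameters
    'log_prob1': np.zeros(n, dtype=np.float32),          # old discrete log-probs
    'log_prob2': np.zeros(n, dtype=np.float32),          # old continuous log-probs
    'r': np.zeros(n, dtype=np.float32),                  # rewards
    'mask': np.ones(n, dtype=np.float32),                # 0 at episode boundaries
    'avail': np.ones((n, n_actions), dtype=np.float32),  # available-action mask
}
tensors = to_tensors(batch)  # what update_network builds before computing GAE

Keeping these eight keys consistent between run_hppo.py and HPPO.update_network is the contract the refactor relies on.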