Commit 27ddcfde authored by Jiakai Song
Browse files

commit

parent 431763bd
......@@ -3,7 +3,10 @@
<component name="ChangeListManager">
<list default="true" id="0cf971cc-497f-4c74-9aa1-f78fca398209" name="Default Changelist" comment="">
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/README.md" beforeDir="false" afterPath="$PROJECT_DIR$/README.md" afterDir="false" />
<change beforePath="$PROJECT_DIR$/algorithms/hppo.py" beforeDir="false" afterPath="$PROJECT_DIR$/algorithms/hppo.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/algorithms/pdqn.py" beforeDir="false" afterPath="$PROJECT_DIR$/algorithms/pdqn.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/run_hppo.py" beforeDir="false" afterPath="$PROJECT_DIR$/run_hppo.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/run_pdqn.py" beforeDir="false" afterPath="$PROJECT_DIR$/run_pdqn.py" afterDir="false" />
</list>
<option name="SHOW_DIALOG" value="false" />
<option name="HIGHLIGHT_CONFLICTS" value="true" />
......
......@@ -241,9 +241,7 @@ class HPPO(Agent):
entropy2_record.append(entropy2)
return np.mean(entropy1_record), np.mean(entropy2_record)
def save_model(self, save_dir=None):
if save_dir is None:
save_dir = './models/HPPO/'
def save_model(self, save_dir):
if not os.path.exists(save_dir):
os.makedirs(save_dir)
torch.save(self.actor.state_dict(), os.path.join(save_dir, 'actor.pkl'))
......
......@@ -288,13 +288,9 @@ class PDQN(Agent):
self.optim2.step()
self.step += 1
def save_model(self, save_dir=None):
if save_dir is None:
save_dir = './models/PDQN'
def save_model(self, save_dir):
if not os.path.exists(save_dir):
os.makedirs(save_dir)
info = {'hidden_size': self.hidden_size, 'frame_stack': self.frame_stack, 'mp': self.mp, 'squash': self.squash}
torch.save(info, os.path.join(save_dir, 'info.pkl'))
torch.save(self.q_network.state_dict(), os.path.join(save_dir, 'q_net.pkl'))
torch.save(self.p_net.state_dict(), os.path.join(save_dir, 'p_net.pkl'))
......
......@@ -31,7 +31,7 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--num_iteration', type=int, default=10000)
parser.add_argument('--save_interval', type=int, default=10)
parser.add_argument('--save_dir', type=str, default=None)
parser.add_argument('--save_dir', type=str, required=True)
parser.add_argument('--batch_size', type=int, default=2048)
parser.add_argument('--frame_stack', type=int, default=1)
parser.add_argument('--lr_a', type=float, default=0.0001)
......
......@@ -18,7 +18,7 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--num_episode', type=int, default=100000)
parser.add_argument('--save_interval', type=int, default=100)
parser.add_argument('--save_dir', type=str, default=None)
parser.add_argument('--save_dir', type=str, required=True)
parser.add_argument('--frame_stack', type=int, default=1)
parser.add_argument('--lr_q', type=float, default=0.0002)
parser.add_argument('--lr_p', type=float, default=0.0001)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment