Commit 40e90f39 authored by Jiakai Song
Browse files

commit

parent c38ce25b
......@@ -3,8 +3,9 @@
<component name="ChangeListManager">
<list default="true" id="0cf971cc-497f-4c74-9aa1-f78fca398209" name="Default Changelist" comment="">
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/algorithms/common.py" beforeDir="false" afterPath="$PROJECT_DIR$/algorithms/agent.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/algorithms/hppo.py" beforeDir="false" afterPath="$PROJECT_DIR$/algorithms/hppo.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/run_hppo.py" beforeDir="false" afterPath="$PROJECT_DIR$/run_hppo.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/algorithms/pdqn.py" beforeDir="false" afterPath="$PROJECT_DIR$/algorithms/pdqn.py" afterDir="false" />
</list>
<option name="SHOW_DIALOG" value="false" />
<option name="HIGHLIGHT_CONFLICTS" value="true" />
......@@ -71,6 +72,22 @@
<screen x="65" y="24" width="1855" height="1056" />
</state>
<state x="1180" y="271" width="424" height="482" key="FileChooserDialogImpl/65.24.1855.1056@65.24.1855.1056" timestamp="1630564222384" />
<state width="1830" height="277" key="GridCell.Tab.0.bottom" timestamp="1630569508698">
<screen x="65" y="24" width="1855" height="1056" />
</state>
<state width="1830" height="277" key="GridCell.Tab.0.bottom/65.24.1855.1056@65.24.1855.1056" timestamp="1630569508698" />
<state width="1830" height="277" key="GridCell.Tab.0.center" timestamp="1630569508697">
<screen x="65" y="24" width="1855" height="1056" />
</state>
<state width="1830" height="277" key="GridCell.Tab.0.center/65.24.1855.1056@65.24.1855.1056" timestamp="1630569508697" />
<state width="1830" height="277" key="GridCell.Tab.0.left" timestamp="1630569508697">
<screen x="65" y="24" width="1855" height="1056" />
</state>
<state width="1830" height="277" key="GridCell.Tab.0.left/65.24.1855.1056@65.24.1855.1056" timestamp="1630569508697" />
<state width="1830" height="277" key="GridCell.Tab.0.right" timestamp="1630569508698">
<screen x="65" y="24" width="1855" height="1056" />
</state>
<state width="1830" height="277" key="GridCell.Tab.0.right/65.24.1855.1056@65.24.1855.1056" timestamp="1630569508698" />
<state x="485" y="174" key="SettingsEditor" timestamp="1630562526490">
<screen x="65" y="24" width="1855" height="1056" />
</state>
......
......@@ -3,7 +3,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import os
from .common import Agent
from .agent import Agent
from torch.distributions import normal, Categorical
......@@ -12,9 +12,42 @@ def layer_init(layer, std=1.0, bias_const=0.0):
torch.nn.init.constant_(layer.bias, bias_const)
class PSActor(nn.Module):
    """Actor with a shared trunk and two heads: discrete action and parameters.

    A stack of ReLU-activated linear layers feeds (a) a linear head producing
    one logit per discrete action and (b) a linear head producing the mean of
    a Normal distribution over ``params_size`` values. The log standard
    deviation is a learnable parameter of shape ``(1, params_size)`` that
    starts at -1 and does not depend on the input.
    """

    def __init__(self, input_size, n_discrete, params_size, hidden_size=None):
        """Create the trunk, both heads and the learnable log-std.

        Args:
            input_size: Input width of the first linear layer.
            n_discrete: Output width of the discrete-action head.
            params_size: Output width of the mean head and of ``log_std``.
            hidden_size: Sequence of hidden-layer widths; ``None`` selects
                ``[256, 256, 256]``.
        """
        super(PSActor, self).__init__()
        widths = [256, 256, 256] if hidden_size is None else hidden_size

        # Trunk: input layer first, then one layer per consecutive width pair.
        trunk = [nn.Linear(input_size, widths[0])]
        trunk += [
            nn.Linear(fan_in, fan_out)
            for fan_in, fan_out in zip(widths[:-1], widths[1:])
        ]
        self.layers = nn.ModuleList(trunk)

        # Heads both read the last hidden width.
        self.discrete_action = nn.Linear(widths[-1], n_discrete)
        self.mu = nn.Linear(widths[-1], params_size)
        self.log_std = nn.Parameter(
            -1.0 * torch.ones([1, params_size]), requires_grad=True
        )

        # Re-initialise every linear layer in creation order (trunk, then the
        # discrete head, then the mean head) with the module's layer_init.
        for module in (*self.layers, self.discrete_action, self.mu):
            layer_init(module, std=1.0)

    def forward(self, state, avail_actions):
        """Return the discrete-action and parameter distributions for ``state``.

        Args:
            state: Input tensor passed through the trunk.
            avail_actions: Mask compared against 0; every position where it
                equals 0 has its logit overwritten with -999999 before the
                softmax. Presumably a tensor matching the logits' shape --
                confirm against the caller.

        Returns:
            A ``(Categorical, Normal)`` pair: the categorical is built from
            the softmax of the masked logits, the normal from the
            tanh-squashed mean and ``exp(log_std)``.
        """
        hidden = state
        for fc in self.layers:
            hidden = F.relu(fc(hidden))

        # Discrete head: mask out entries flagged 0, then normalise.
        logits = self.discrete_action(hidden)
        logits[avail_actions == 0] = -999999
        action_dist = Categorical(torch.softmax(logits, dim=-1))

        # Parameter head: mean squashed into (-1, 1), state-independent std.
        mean = torch.tanh(self.mu(hidden))
        spread = torch.exp(self.log_std)
        param_dist = normal.Normal(mean, spread)

        return action_dist, param_dist
class Actor(nn.Module):
def __init__(self, input_size, n_discrete, params_size, hidden_size=None):
super(Actor, self).__init__()
super(Actor, self).__init__()
if hidden_size is None:
hidden_size = [256, 256, 256]
......
......@@ -7,7 +7,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from .common import Agent
from .agent import Agent
Transition = namedtuple("Transition",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment