"""Train and evaluate a PPO agent on a toy environment.

Loads hyperparameters from a YAML experiment config, builds a vectorized
environment, constructs a recurrent (GRU) actor-critic network, runs PPO
training, and prints mean/std statistics from a 50-episode evaluation.
"""
import torch

from envs.make_env_funcs import make_env_fn
from envs.vec_dummy_env import DummyVecEnv
from common.common_tools import get_config

from torch_agents.agents.ppo_agent import PPO_Agent
from torch_agents.utils.backbones import MLP_Backbone, MLP_GRU_Backbone, MLP_LSTM_Backbone
from torch_agents.policies.discrete import DiscreteActorCriticNet
from torch_agents.policies.gaussian import GaussianActorCriticNet


def main():
    """Entry point: config -> envs -> network -> optimizer -> train/test."""
    # Experiment config (device, nenvs, learning_rate, step counts, ...).
    config = get_config("configs/ppo/ppo_toy.yaml")

    # Vectorized environment: one env instance per parallel worker index.
    envs = DummyVecEnv([make_env_fn(config.environment, i) for i in range(config.nenvs)])
    observation_space = envs.observation_space
    action_space = envs.action_space

    # Recurrent feature extractor: GRU with 64 hidden units and no extra MLP
    # layers in front of it. Alternatives (same keyword style):
    #   MLP_Backbone(observation_space, hidden_size=(64,), activation=torch.nn.Tanh, ...)
    #   MLP_LSTM_Backbone(observation_space, 64, hidden_size=(), ...)
    # torch.nn.init.xavier_normal is deprecated; use the in-place
    # xavier_normal_ (identical call signature for a single weight tensor).
    backbone = MLP_GRU_Backbone(observation_space,
                                64,
                                hidden_size=(),
                                activation=torch.nn.LeakyReLU,
                                initialize=torch.nn.init.xavier_normal_,
                                device=config.device)

    # Discrete actor-critic heads on top of the shared backbone. For a
    # continuous action space, use GaussianActorCriticNet with the same
    # arguments instead.
    policy = DiscreteActorCriticNet(action_space,
                                    backbone,
                                    actor_hidden_size=(64,),
                                    critic_hidden_size=(64,),
                                    activation=torch.nn.Tanh,
                                    initialize=torch.nn.init.xavier_normal_,
                                    device=config.device)

    # Adam with a small epsilon (common PPO stability setting) and a
    # step-decay learning-rate schedule.
    optimizer = torch.optim.Adam(policy.parameters(), lr=config.learning_rate, eps=1e-5)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.99, last_epoch=-1)

    agent = PPO_Agent(config, envs, policy, optimizer, scheduler, device=config.device)
    agent.train(int(config.init_steps), int(config.train_steps), None, 10000)

    # agent.test(50) returns a 2-tuple — presumably mean/std of episode
    # returns over 50 evaluation episodes (verify against PPO_Agent.test).
    mu, std = agent.test(50)
    print(f"test_mean={mu},test_std={std}")


if __name__ == "__main__":
    main()