Source code for virne.solver.learning.pg_seq2seq.solver

# ==============================================================================
# Copyright 2023 GeminiLight (wtfly2018@gmail.com). All Rights Reserved.
# ==============================================================================


import os
import time
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical

from virne.base.environment import SolutionStepEnvironment
from virne.solver import registry
from virne.base import Solution
from virne.solver.learning.rl_base.buffer import RolloutBuffer
from .instance_env import InstanceEnv
from .net import ActorCritic
from virne.solver.learning.rl_base import InstanceAgent, RLSolver, PGSolver


@registry.register(
    solver_name='pg_seq2seq',
    env_cls=SolutionStepEnvironment,
    solver_type='r_learning')
class PgSeq2SeqSolver(InstanceAgent, PGSolver):
    """
    A Reinforcement Learning-based solver that uses
    Policy Gradient (PG) as the training algorithm
    and Sequence-to-Sequence (Seq2Seq) as the neural network model.
    """

    def __init__(self, controller, recorder, counter, **kwargs):
        """Initialise both base classes and register the observation preprocessors.

        Args:
            controller: Forwarded to ``PGSolver.__init__`` and to each ``InstanceEnv``.
            recorder: Forwarded to ``PGSolver.__init__`` and to each ``InstanceEnv``.
            counter: Forwarded to ``PGSolver.__init__`` and to each ``InstanceEnv``.
            **kwargs: Extra options forwarded to ``PGSolver.__init__``.
        """
        InstanceAgent.__init__(self, InstanceEnv)
        PGSolver.__init__(self, controller, recorder, counter, make_policy, obs_as_tensor, **kwargs)
        # The encoder consumes a different observation layout than the decoder,
        # so it gets its own preprocessing function.
        self.preprocess_encoder_obs = encoder_obs_to_tensor
        self.compute_advantage_method = 'mc'

    def solve(self, instance):
        """Decode a placement for one instance with the current policy.

        Args:
            instance: Mapping with the keys ``'v_net'`` and ``'p_net'``.

        Returns:
            The ``solution`` attribute of the per-instance environment after
            the environment reports that the episode is done.
        """
        v_net, p_net = instance['v_net'], instance['p_net']
        sub_env = self.InstanceEnv(p_net, v_net, self.controller, self.recorder, self.counter, **self.basic_config)
        instance_obs = sub_env.get_observation()
        # Run the encoder once per instance; the return value is not needed here.
        # Presumably this primes the RNN state read by get_last_rnn_state() below.
        self.policy.encode(self.preprocess_encoder_obs(instance_obs, device=self.device))
        # NOTE(review): the first decoder input is p_net.num_nodes, presumably a
        # start-token index one past the valid node ids — confirm against the net.
        p_node_id = p_net.num_nodes
        instance_done = False
        while not instance_done:
            hidden_state, cell_state = self.policy.get_last_rnn_state()
            instance_obs = {
                'p_node_id': p_node_id,
                'hidden_state': np.squeeze(hidden_state.cpu().detach().numpy(), axis=0),
                'cell_state': np.squeeze(cell_state.cpu().detach().numpy(), axis=0),
            }
            mask = np.expand_dims(sub_env.generate_action_mask(), axis=0)
            tensor_instance_obs = self.preprocess_obs(instance_obs, device=self.device)
            action, action_logprob = self.select_action(tensor_instance_obs, mask=mask, sample=True)
            _, _, instance_done, _ = sub_env.step(action)
            # The chosen node becomes the decoder input of the next step.
            p_node_id = action.item()
        return sub_env.solution

    def learn_with_instance(self, instance):
        """Roll out one instance and collect the transitions for training.

        Args:
            instance: Mapping with the keys ``'v_net'`` and ``'p_net'``.

        Returns:
            A tuple ``(solution, sub_buffer, last_value)`` where ``solution`` is
            the environment's solution, ``sub_buffer`` is the ``RolloutBuffer``
            filled during the rollout, and ``last_value`` is the value estimate
            of the final observation (``None`` when the policy has no
            ``evaluate`` attribute).
        """
        # Sub-environment and buffer dedicated to this single instance.
        sub_buffer = RolloutBuffer()
        v_net, p_net = instance['v_net'], instance['p_net']
        sub_env = self.InstanceEnv(p_net, v_net, self.controller, self.recorder, self.counter, **self.basic_config)
        instance_obs = sub_env.get_observation()
        # Run the encoder once per instance; the return value is not needed here.
        self.policy.encode(self.preprocess_encoder_obs(instance_obs, device=self.device))
        p_node_id = p_net.num_nodes
        # Value estimates are only produced when the policy exposes `evaluate`.
        has_critic = hasattr(self.policy, 'evaluate')
        instance_done = False
        while not instance_done:
            hidden_state, cell_state = self.policy.get_last_rnn_state()
            # Build the mask once so that the observation, the action selection
            # and the buffer all see the same array. The original referenced an
            # undefined name `mask` here, which raised NameError.
            mask = np.expand_dims(sub_env.generate_action_mask(), axis=0)
            instance_obs = {
                'p_node_id': p_node_id,
                'hidden_state': np.squeeze(hidden_state.cpu().detach().numpy(), axis=0),
                'cell_state': np.squeeze(cell_state.cpu().detach().numpy(), axis=0),
                'action_mask': mask,
            }
            tensor_instance_obs = self.preprocess_obs(instance_obs, device=self.device)
            action, action_logprob = self.select_action(tensor_instance_obs, mask=mask, sample=True)
            value = self.estimate_value(tensor_instance_obs) if has_critic else None
            next_instance_obs, instance_reward, instance_done, instance_info = sub_env.step(action)
            # The chosen node becomes the decoder input of the next step.
            p_node_id = action.item()
            sub_buffer.add(instance_obs, action, instance_reward, instance_done, action_logprob, value=value)
            sub_buffer.action_masks.append(mask)
        # NOTE(review): next_instance_obs comes straight from sub_env.step();
        # confirm it carries the 'p_node_id' / 'hidden_state' / 'cell_state'
        # keys that obs_as_tensor expects before relying on last_value.
        last_value = self.estimate_value(self.preprocess_obs(next_instance_obs, self.device)) if has_critic else None
        solution = sub_env.solution
        return solution, sub_buffer, last_value
def make_policy(agent, **kwargs):
    """Build the actor-critic policy and its Adam optimizer.

    Args:
        agent: Object providing ``embedding_dim``, ``device``, ``lr_actor``
            and ``weight_decay``.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        A ``(policy, optimizer)`` tuple.
    """
    # NOTE(review): the original inline comment listed four feature groups
    # "(n_attrs, e_attrs, dist, degree)" next to the value 3 — confirm which
    # of the two is correct.
    feature_dim = 3
    # NOTE(review): hard-coded; presumably the number of physical nodes — confirm.
    action_dim = 100
    policy = ActorCritic(feature_dim, action_dim, agent.embedding_dim).to(agent.device)
    optimizer = torch.optim.Adam([
        {'params': policy.parameters(), 'lr': agent.lr_actor},
    ], weight_decay=agent.weight_decay)
    return policy, optimizer


def encoder_obs_to_tensor(obs, device):
    """Preprocess encoder observation(s) into batched tensors on ``device``.

    Args:
        obs: A single observation dict with the key ``'p_net_x'``, or a list
            of such dicts.
        device: Target device passed to ``Tensor.to``.

    Returns:
        ``{'p_net_x': FloatTensor}``. A single observation gets a leading
        batch axis of size 1; a list is stacked along a new leading axis.

    Raises:
        ValueError: If ``obs`` is neither a dict nor a list (same error as
            ``obs_as_tensor``; previously ``None`` was returned silently).
    """
    # Single observation: add the batch axis explicitly.
    if isinstance(obs, dict):
        obs_p_net_x = torch.FloatTensor(obs['p_net_x']).unsqueeze(dim=0).to(device)
        return {'p_net_x': obs_p_net_x}
    # Batch of observations: stacking through numpy creates the batch axis.
    if isinstance(obs, list):
        p_net_x_list = [observation['p_net_x'] for observation in obs]
        obs_p_net_x = torch.FloatTensor(np.array(p_net_x_list)).to(device)
        return {'p_net_x': obs_p_net_x}
    raise ValueError('obs type error')


def obs_as_tensor(obs, device):
    """Preprocess decoder observation(s) into batched tensors on ``device``.

    Args:
        obs: A single observation dict with the keys ``'p_node_id'``,
            ``'hidden_state'`` and ``'cell_state'``, or a list of such dicts.
        device: Target device passed to ``Tensor.to``.

    Returns:
        A dict with the same three keys. ``'p_node_id'`` is a LongTensor; the
        two state tensors are FloatTensors with the batch on axis 1 (inserted
        by ``unsqueeze`` for a single observation, moved there by ``permute``
        for a list).

    Raises:
        ValueError: If ``obs`` is neither a dict nor a list.
    """
    # Single observation.
    if isinstance(obs, dict):
        obs_p_node_id = torch.LongTensor([obs['p_node_id']]).to(device)
        obs_hidden_state = torch.FloatTensor(obs['hidden_state']).unsqueeze(dim=1).to(device)
        obs_cell_state = torch.FloatTensor(obs['cell_state']).unsqueeze(dim=1).to(device)
        return {'p_node_id': obs_p_node_id, 'hidden_state': obs_hidden_state, 'cell_state': obs_cell_state}
    # Batch of observations.
    if isinstance(obs, list):
        p_node_id_list = [observation['p_node_id'] for observation in obs]
        hidden_state_list = [observation['hidden_state'] for observation in obs]
        cell_state_list = [observation['cell_state'] for observation in obs]
        obs_p_node_id = torch.LongTensor(np.array(p_node_id_list)).to(device)
        # Stacking puts the batch on axis 0; swap it to axis 1 to match the
        # layout produced by the single-observation branch.
        obs_hidden_state = torch.FloatTensor(np.array(hidden_state_list)).permute(1, 0, 2).to(device)
        obs_cell_state = torch.FloatTensor(np.array(cell_state_list)).permute(1, 0, 2).to(device)
        return {'p_node_id': obs_p_node_id, 'hidden_state': obs_hidden_state, 'cell_state': obs_cell_state}
    raise ValueError('obs type error')