# Source code for virne.solver.learning.reinforcement_learning.cnn_solver

# ==============================================================================
# Copyright 2023 GeminiLight (wtfly2018@gmail.com). All Rights Reserved.
# ==============================================================================


import numpy as np
from gym import spaces
from typing import Any, Dict, Tuple, List, Union, Optional, Type, Callable


import torch
import numpy as np
from omegaconf import DictConfig, open_dict

from virne.network import PhysicalNetwork, VirtualNetwork
from virne.core import Controller, Recorder, Counter, Solution, Logger

from virne.solver import SolverRegistry
from virne.solver.learning.rl_core import JointPRStepInstanceRLEnv, PlaceStepInstanceRLEnv
from virne.solver.learning.rl_core.rl_solver import PGSolver, A2CSolver, PPOSolver, A3CSolver
from virne.solver.learning.rl_core.instance_agent import InstanceAgent
from virne.solver.learning.rl_core.tensor_convertor import TensorConvertor
from virne.solver.learning.rl_core.policy_builder import PolicyBuilder
from virne.solver.learning.reinforcement_learning.solver_maker import make_solver_class


# CNN-specific hooks used by the solvers in this module: the converter that
# turns observations into tensors, and the factory that builds the policy.
obs_as_tensor, build_policy = (
    TensorConvertor.obs_as_tensor_for_cnn,
    PolicyBuilder.build_cnn_policy,
)


class PgCnnInstanceRLEnv(PlaceStepInstanceRLEnv):
    """Place-step RL environment used by the ``pg_cnn`` solver.

    Before delegating to ``PlaceStepInstanceRLEnv``, this overrides the
    following settings on ``config``, which is modified in place:

    - feature constructor ``p_net_v_node``, with degree metrics, aggregated
      link attributes and node status flags enabled, and the extra
      topological metrics disabled;
    - the ``vanilla`` reward calculator;
    - negative sampling turned off.
    """

    def __init__(self, p_net: PhysicalNetwork, v_net: VirtualNetwork, controller: Controller, recorder: Recorder, counter: Counter, logger: Logger, config: DictConfig, **kwargs):
        feature_overrides = {
            'name': 'p_net_v_node',
            'if_use_degree_metric': True,
            'if_use_more_topological_metrics': False,
            'if_use_aggregated_link_attrs': True,
            'if_use_node_status_flags': True,
        }
        # open_dict lifts OmegaConf's struct restriction so that keys not
        # already present in the config may be written.
        with open_dict(config):
            feature_cfg = config.rl.feature_constructor
            for key, value in feature_overrides.items():
                setattr(feature_cfg, key, value)
            config.rl.reward_calculator.name = 'vanilla'
            config.rl.if_use_negative_sample = False
        super().__init__(p_net, v_net, controller, recorder, counter, logger, config, **kwargs)


@SolverRegistry.register(solver_name='pg_cnn', solver_type='r_learning')
class PgCnnSolver(InstanceAgent, PGSolver):
    """
    Reinforcement-learning solver trained with Policy Gradient (PG), whose
    neural network model is a Convolutional Neural Network (CNN).
    """

    def __init__(self, controller, recorder, counter, logger, config, **kwargs):
        # Give InstanceAgent the CNN place-step environment class ...
        InstanceAgent.__init__(self, PgCnnInstanceRLEnv)
        # ... then initialise PGSolver with the module's CNN policy builder
        # and CNN observation-to-tensor converter.
        PGSolver.__init__(
            self,
            controller,
            recorder,
            counter,
            logger,
            config,
            build_policy,
            obs_as_tensor,
            **kwargs,
        )
class PgCnn2InstanceRLEnv(PlaceStepInstanceRLEnv):
    """Place-step RL environment used by the ``pg_cnn2`` solver.

    Before delegating to ``PlaceStepInstanceRLEnv``, this overrides the
    following settings on ``config``, which is modified in place:

    - feature constructor ``p_net_v_node``, with degree metrics, the extra
      topological metrics, aggregated link attributes and node status flags
      all enabled;
    - the ``vanilla`` reward calculator.

    ``config.rl.if_use_negative_sample`` is left untouched.
    """

    def __init__(self, p_net: PhysicalNetwork, v_net: VirtualNetwork, controller: Controller, recorder: Recorder, counter: Counter, logger: Logger, config: DictConfig, **kwargs):
        # open_dict lifts OmegaConf's struct restriction for these writes.
        with open_dict(config):
            feature_cfg = config.rl.feature_constructor
            feature_cfg.name = 'p_net_v_node'
            feature_cfg.if_use_degree_metric = True
            feature_cfg.if_use_more_topological_metrics = True
            feature_cfg.if_use_aggregated_link_attrs = True
            feature_cfg.if_use_node_status_flags = True
            config.rl.reward_calculator.name = 'vanilla'
        super().__init__(p_net, v_net, controller, recorder, counter, logger, config, **kwargs)
@SolverRegistry.register(solver_name='pg_cnn2', solver_type='r_learning')
class PgCnn2Solver(InstanceAgent, PGSolver):
    """
    Reinforcement-learning solver trained with Policy Gradient (PG), whose
    neural network model is a Convolutional Neural Network (CNN).

    Compared with ``pg_cnn``, the CNN receives additional graph features as
    input.
    """

    def __init__(self, controller, recorder, counter, logger, config, **kwargs):
        # Give InstanceAgent the feature-richer CNN place-step environment ...
        InstanceAgent.__init__(self, PgCnn2InstanceRLEnv)
        # ... then initialise PGSolver with the module's CNN policy builder
        # and CNN observation-to-tensor converter.
        PGSolver.__init__(
            self,
            controller,
            recorder,
            counter,
            logger,
            config,
            build_policy,
            obs_as_tensor,
            **kwargs,
        )
class CnnInstanceEnv(JointPRStepInstanceRLEnv):
    """Instance environment shared by the extended ``*_cnn+`` solvers.

    Sets ``config.rl.feature_constructor.name`` to ``'p_net_v_node'`` (the
    passed ``config`` is modified in place) and otherwise defers entirely to
    ``JointPRStepInstanceRLEnv``.
    """

    def __init__(self, p_net: PhysicalNetwork, v_net: VirtualNetwork, controller: Controller, recorder: Recorder, counter: Counter, logger: Logger, config: DictConfig, **kwargs):
        # open_dict lifts OmegaConf's struct restriction for this write.
        with open_dict(config):
            config.rl.feature_constructor.name = 'p_net_v_node'
        super().__init__(p_net, v_net, controller, recorder, counter, logger, config, **kwargs)


# One spec per extended CNN solver. Each spec gives the registry name and the
# base RL solver class; all of them share ``CnnInstanceEnv``.
extended_cnn_solvers = [
    {'solver_name': 'pg_cnn+', 'solver_cls_name': 'PgCnnSolver', 'rl_solver_cls': PGSolver},
    {'solver_name': 'a2c_cnn+', 'solver_cls_name': 'A2cCnnSolver', 'rl_solver_cls': A2CSolver},
    {'solver_name': 'a3c_cnn+', 'solver_cls_name': 'A3cCnnSolver', 'rl_solver_cls': A3CSolver},
    {'solver_name': 'ppo_cnn+', 'solver_cls_name': 'PpoCnnSolver', 'rl_solver_cls': PPOSolver},
]
# Backward-compatible alias for the original (misspelled) module attribute.
extented_cnn_solvers = extended_cnn_solvers


def _register_extended_cnn_solvers(solver_specs):
    """Make and register one solver class per spec via ``make_solver_class``.

    Every spec uses ``CnnInstanceEnv``, the module's ``build_policy`` and
    ``TensorConvertor.obs_as_tensor_for_cnn``. Only the solver name and the
    base RL solver class vary between specs.

    NOTE(review): ``solver_cls_name`` is carried in each spec but is not
    passed to ``make_solver_class``. Confirm whether that is intended.
    """
    cnn_obs_as_tensor = TensorConvertor.obs_as_tensor_for_cnn
    for spec in solver_specs:
        make_solver_class(
            spec['solver_name'],
            CnnInstanceEnv,
            spec['rl_solver_cls'],
            build_policy,
            cnn_obs_as_tensor,
        )


# Run inside a function so the loop temporaries do not leak into the module
# namespace and the module-level ``obs_as_tensor`` alias is not rebound.
_register_extended_cnn_solvers(extended_cnn_solvers)