Source code for free_range_zoo.envs.cybersecurity.env.structures.configuration
"""Configuration classes for the cybersecurity environment."""
from __future__ import annotations
from typing import Tuple
from dataclasses import dataclass
import functools
import torch
from free_range_zoo.utils.configuration import Configuration
from free_range_zoo.envs.cybersecurity.env.transitions import (MovementTransition, PresenceTransition, SubnetworkTransition)
[docs]
@dataclass
class CybersecurityConfiguration(Configuration):
"""
Configuration for the cybersecurity environment.
Attributes:
attacker_config: AttackerConfiguration - Configuration for the attacker agent properties
defender_config: DefenderConfiguration - Configuration for the defender agent properties
network_config: NetworkConfiguration - Configuration for the network nodes
reward_config: RewardConfiguration - Configuration for the environment rewards
stochastic_config: StochasticConfiguration - Configuration for the stochastic components of the environment
"""
attacker_config: AttackerConfiguration
defender_config: DefenderConfiguration
network_config: NetworkConfiguration
reward_config: RewardConfiguration
stochastic_config: StochasticConfiguration
@functools.cached_property
def movement_transition(self) -> MovementTransition:
"""Get the movement transition function for the environment."""
return MovementTransition()
@functools.cached_property
def presence_transition(self) -> PresenceTransition:
"""Get the presence transition function for the environment."""
return PresenceTransition(self.persist_probs, self.return_probs, self.attacker_config.num_attackers)
@functools.cached_property
def subnetwork_transition(self) -> SubnetworkTransition:
"""Get the subnetwork transition function for the environment."""
return SubnetworkTransition(self.network_config.patched_states, self.network_config.vulnerable_states,
self.network_config.exploited_states, self.network_config.temperature,
self.stochastic_config.network_state)
@functools.cached_property
def attacker_observation_bounds(self) -> torch.Tensor:
"""Get the observation bounds for the agent (threat, presence)."""
return tuple([self.attacker_config.highest_threat, 1])
@functools.cached_property
def defender_observation_bounds(self) -> torch.Tensor:
"""Get the observation bounds for the agent (mitigation, presence, location)."""
return tuple([self.defender_config.highest_mitigation, 1, self.network_config.num_nodes - 1])
@functools.cached_property
def network_observation_bounds(self) -> Tuple[int, int]:
"""Get the observation bounds for the subnetwork (state)."""
return tuple([self.network_config.num_states])
@functools.cached_property
def num_agents(self) -> int:
"""Get the number of agents of all types in the environment."""
return self.attacker_config.num_attackers + self.defender_config.num_defenders
@functools.cached_property
def persist_probs(self) -> torch.FloatTensor:
"""Get the persist probabilities for all agents."""
return torch.cat([self.attacker_config.persist_probs, self.defender_config.persist_probs])
@functools.cached_property
def return_probs(self) -> torch.FloatTensor:
"""Get the return probabilities for all agents."""
return torch.cat([self.attacker_config.return_probs, self.defender_config.return_probs])
@functools.cached_property
def initial_presence(self) -> torch.BoolTensor:
"""Get the initial presence of all agents."""
return torch.cat([self.attacker_config.initial_presence, self.defender_config.initial_presence])
def validate(self) -> bool:
"""
Validate the configuration.
Returns:
bool - True if the configuration is valid, nothing otherwise
"""
if self.reward_config.network_state_rewards.size(0) != self.network_config.num_states:
raise ValueError("The number of network state rewards must match the number of network states.")
return True
[docs]
@dataclass
class AttackerConfiguration(Configuration):
"""
Configuration for the attacker in the cybersecurity environment.
Attributes:
initial_presence: torch.BoolTensor - Initial presence of each attacking agent
threat: torch.FloatTensor - Threat values for each attacking agent
persist_probs: torch.FloatTensor - Probability for each attacking agent to leave the environment
return_probs: torch.FloatTensor - Probability for each attacking agent to return to the environment
"""
initial_presence: torch.BoolTensor
threat: torch.FloatTensor
persist_probs: torch.FloatTensor
return_probs: torch.FloatTensor
@functools.cached_property
def num_attackers(self) -> int:
"""Get the number of attackers."""
return self.threat.size(0)
@functools.cached_property
def highest_threat(self) -> float:
"""Get the highest threat value of all attackers."""
return self.threat.max().item()
def validate(self) -> bool:
"""Validate the configuration."""
if self.persist_probs.min() < 0 or self.persist_probs.max() > 1:
raise ValueError('Persist probabilities must be between 0 and 1.')
if self.return_probs.min() < 0 or self.return_probs.max() > 1:
raise ValueError('Return probabilities must be between 0 and 1.')
if self.threat.size(0) != self.persist_probs.size(0) or self.threat.size(0) != self.return_probs.size(0):
raise ValueError("The size of threats must match the size of persist and return probabilities.")
if self.threat.size(0) != self.initial_presence.size(0):
raise ValueError("The size of threats must match the size of initial presence values.")
return True
[docs]
@dataclass
class DefenderConfiguration(Configuration):
"""
Configuration for the defender in the cybersecurity environment.
Attributes:
initial_location: torch.IntTensor - Initial location of each defending agent
initial_presence: torch.BoolTensor - Initial presence of each defending agent
mitigation: torch.FloatTensor - mitigation values for each defending agent
persist_probs: torch.FloatTensor - Probability for each defending agent to leave the environment
return_probs: torch.FloatTensor - Probability for each defending agent to return to the environment
"""
initial_location: torch.IntTensor
initial_presence: torch.BoolTensor
mitigation: torch.FloatTensor
persist_probs: torch.FloatTensor
return_probs: torch.FloatTensor
@functools.cached_property
def num_defenders(self) -> int:
"""Get the number of defenders."""
return self.mitigation.size(0)
@functools.cached_property
def highest_mitigation(self) -> float:
"""Get the highest mitigation value of all defenders."""
return self.mitigation.max().item()
def validate(self) -> bool:
"""Validate the configuration."""
if self.persist_probs.min() < 0 or self.persist_probs.max() > 1:
raise ValueError('Persist probabilities must be between 0 and 1.')
if self.return_probs.min() < 0 or self.return_probs.max() > 1:
raise ValueError('Return probabilities must be between 0 and 1.')
if self.mitigation.size(0) != self.persist_probs.size(0) or self.mitigation.size(0) != self.return_probs.size(0):
raise ValueError("The size of mitigations must match the size of persist and return probabilities.")
if self.mitigation.size(0) != self.initial_location.size(0) or self.mitigation.size(0) != self.initial_presence.size(0):
raise ValueError("The size of mitigations must match the size of initial location and presence values.")
return True
[docs]
@dataclass
class NetworkConfiguration(Configuration):
"""
Configuration for the network components of the cybersecurity simulation.
The home node for the simulation is automatically defined as node -1.
Attributes:
patched_states: int - Number of patched states in the network
vulnerable_states: int - Number of vulnerable states in the network
exploited_states: int - Number of exploited states in the network
temperature: float - Temperature for the softmax function for the danger score
initial_state: torch.IntTensor - Subnetwork-parallel array representing the exploitment state of each subnetwork
adj_matrix: torch.BoolTensor - 2D array representing adjacency matrix for all subnetwork connections
"""
patched_states: int
vulnerable_states: int
exploited_states: int
temperature: float
initial_state: torch.IntTensor
adj_matrix: torch.BoolTensor
@functools.cached_property
def criticality(self) -> torch.FloatTensor:
"""Get the criticality of each node. Based on the number of outward connections."""
return self.adj_matrix.sum(dim=1)
@functools.cached_property
def num_nodes(self) -> int:
"""Get the number of nodes in the network."""
return self.adj_matrix.size(0)
@functools.cached_property
def num_states(self) -> int:
"""Get the number of states in the network."""
return self.patched_states + self.vulnerable_states + self.exploited_states
def validate(self) -> bool:
"""Validate the configuration."""
if self.initial_state.size(0) != self.num_nodes:
raise ValueError("The size of initial state must match the number of nodes.")
if self.adj_matrix.size(0) != self.adj_matrix.size(1):
raise ValueError("The adjacency matrix must be square.")
return True
[docs]
@dataclass
class StochasticConfiguration(Configuration):
"""
Configuration for the stochastic components of the cybersecurity simulation.
Attributes:
network_state: bool - Whether the subnetwork states degrade / repair stochastically
"""
network_state: bool
def validate(self) -> bool:
"""Validate the configuration."""
return True
[docs]
@dataclass
class RewardConfiguration(Configuration):
"""
Configuration for the rewards in the cybersecurity environment.
Attributes:
bad_action_penalty: float - Penalty for committing a bad action (patching while at the home node)
patch_reward: float - Reward (or penalty) for patching a node
network_state_rewards: torch.FloatTensor - Subnetwork-parallel array representing the rewards for each
"""
bad_action_penalty: float
patch_reward: float
network_state_rewards: torch.FloatTensor
def validate(self) -> bool:
"""Validate the configuration."""
return True