"""
# Wildfire
## Description
The wildfire domain simulates a grid-based environment where fires can spread and agents are tasked with extinguishing
them by applying suppressant. The environment is dynamic and partially observable, with fires that can spread across
adjacent tiles and vary in intensity. Fires can also burn out once they reach a certain intensity threshold.
<u>**Environment Dynamics**</u><br>
- Fire Spread: Fires start at designated locations and spread to neighboring tiles, increasing in intensity over
time. The intensity of the fire influences how much suppressant is needed to extinguish it. Fires will continue
to spread until they either burn out or are controlled by agents.
- Fire Intensity and Burnout: As fires spread, their intensity increases, making them harder to fight. Once a
fire reaches a critical intensity, it may burn out naturally, stopping its spread and extinguishing itself.
However, this is unpredictable, and timely intervention is often necessary to prevent further damage.
- Suppression Mechanism: Agents apply suppressant to the fire to reduce its intensity. However, suppressant is a
finite resource. When an agent runs out of suppressant, they must leave the environment to refill at a designated
station before returning to continue fighting fires.
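The loop these dynamics create (fight, run out of suppressant, refill, return) can be sketched with a toy
example. The snippet below is purely illustrative pseudologic; the names and numbers are made up and it is
not the environment's actual transition code.
```python
# Toy illustration of the suppression / burnout / refill loop described above.
# All names and thresholds here are made up for illustration only.
fire_intensity = 3           # current intensity of a single fire
burnout_threshold = 5        # fires at or above this intensity burn out (stochastic in the real environment)
suppressant = 2              # the agent's remaining suppressant charges
while 0 < fire_intensity < burnout_threshold:
    if suppressant > 0:
        fire_intensity -= 1  # applying suppressant lowers the fire's intensity
        suppressant -= 1     # suppressant is a finite resource
    else:
        suppressant = 2      # the agent leaves to refill at a station...
        fire_intensity += 1  # ...and the fire grows unopposed in the meantime
# fire_intensity == 0 -> extinguished; fire_intensity >= burnout_threshold -> burned out
```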
<u>**Environment Openness**</u><br>
- **agent openness**: Environments where agents can dynamically enter and leave, enabling dynamic cooperation and
multi-agent scenarios with evolving participants.
- `wildfire`: Agents can run out of suppressant and leave the environment, removing their contributions
to existing fires. Agents must reason about their collaborators leaving, or new collaborators entering.
- **task openness**: Tasks can be introduced or removed from the environment, allowing for flexible goal setting
    and adaptable planning models.
    - `wildfire`: Fires can spread beyond their original starting point, requiring agents to reason about new
        tasks possibly entering the environment as well as a changing action space.
- **frame / type openness**: Different frames (e.g. agent abilities or skills) can be added, removed, or modified,
    expanding the environment's complexity and requiring agents to infer their neighbors' changing abilities.
    - `wildfire`: Agents can damage their equipment over time, and have their capabilities slowly degrade. On
        the other hand, agents might also receive different equipment upon leaving the environment to resupply.
# Specification
---
| Import | `from free_range_zoo.envs import wildfire_v0` |
| ------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| Actions | Discrete & Stochastic |
| Observations       | Discrete and fully observed with private observations |
| Parallel API | Yes |
| Manual Control | No |
| Agent Names        | [$firefighter$_1, ..., $firefighter$_n] |
| # Agents | [0, $n_{firefighters}$] |
| Action Shape | ($envs$, 2) |
| Action Values | [$fight_0$, ..., $fight_{tasks}$, $noop$ (-1)] |
| Observation Shape | TensorDict: { <br> **self**: $<ypos, xpos, fire power, suppressant>$<br> **others**: $<ypos,xpos,fire power, suppressant>$<br> **tasks**: $<y, x, fire level, intensity>$ <br> **batch_size**: $num\_envs$ } |
| Observation Values | <u>**self**</u>:<br> $ypos$: $[0, max_y]$<br> $xpos$: $[0, max_x]$<br> $fire\_power\_reduction$: $[0, max_{fire\_power\_reduction}]$<br> $suppressant$: $[0, max_{suppressant}]$<br><u>**others**</u>:<br> $ypos$: $[0, max_y]$<br> $xpos$: $[0, max_x]$<br> $fire\_power\_reduction$: $[0, max_{fire\_power\_reduction}]$<br> $suppressant$: $[0, max_{suppressant}]$<br> <u>**tasks**</u><br> $ypos$: $[0, max_y]$<br> $xpos$: $[0, max_x]$<br> $fire\_level$: $[0, max_{fire\_level}]$<br> $intensity$: $[0, num_{fire\_states}]$ |
---
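<u>**Example Usage**</u><br>
A minimal usage sketch. The configuration objects needed to build the environment are omitted, the keyword
names shown in comments are assumptions, and the reset/step calls assume the standard parallel PettingZoo
convention; check all of these against your installed version of `free_range_zoo`.
```python
import torch
from free_range_zoo.envs import wildfire_v0

# Construction of the wildfire configuration is omitted here; the keyword names below
# (configuration, parallel_envs, device) are assumptions, not a verified signature.
env = wildfire_v0.parallel_env(
    # configuration=<wildfire configuration>, parallel_envs=<int>, device=<torch.device>,
)
observations, infos = env.reset()

# Observations are TensorDicts keyed by 'self', 'others', and 'tasks' (see the table above).
self_view = observations['firefighter_1']['self']

# Each agent submits one (task index, action value) pair per parallel environment;
# an action value of -1 is the noop / refill action.
actions = {agent: torch.tensor([[0, 0]]) for agent in env.agents}
observations, rewards, terminations, truncations, infos = env.step(actions)
```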
"""
from typing import Tuple, Dict, Any, Union, List, Optional, Callable
import torch
from tensordict.tensordict import TensorDict
import gymnasium
from pettingzoo.utils.wrappers import OrderEnforcingWrapper
from free_range_zoo.utils.env import BatchedAECEnv
from free_range_zoo.utils.conversions import batched_aec_to_batched_parallel
from free_range_zoo.envs.wildfire.env import transitions
from free_range_zoo.envs.wildfire.env.utils import in_range_check
from free_range_zoo.envs.wildfire.env.spaces import actions, observations
from free_range_zoo.envs.wildfire.env.structures.state import WildfireState
def parallel_env(wrappers: List[Callable] = [], **kwargs) -> BatchedAECEnv:
"""
    Parallelized version of the wildfire environment.
Args:
wrappers: List[Callable[[BatchedAECEnv], BatchedAECEnv]] - the wrappers to apply to the environment
Returns:
BatchedAECEnv: the parallelized wildfire environment
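    Example:
        A minimal sketch; `MyWrapper` and `environment_configuration` below are placeholders,
        not names provided by this module:
        >>> env = parallel_env(wrappers=[MyWrapper], **environment_configuration)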
"""
env = raw_env(**kwargs)
env = OrderEnforcingWrapper(env)
for wrapper in wrappers:
env = wrapper(env)
env = batched_aec_to_batched_parallel(env)
return env
def env(wrappers: List[Callable], **kwargs) -> BatchedAECEnv:
"""
AEC wrapped version of the wildfire environment.
Args:
wrappers: List[Callable[[BatchedAECEnv], BatchedAECEnv]] - the wrappers to apply to the environment
Returns:
BatchedAECEnv: the AEC wrapped wildfire environment
"""
env = raw_env(**kwargs)
env = OrderEnforcingWrapper(env)
for wrapper in wrappers:
env = wrapper(env)
return env
class raw_env(BatchedAECEnv):
"""Environment definition for the wildfire environment."""
metadata = {"render.modes": ["human", "rgb_array"], "name": "wildfire_v0", "is_parallelizable": True, "render_fps": 2}
@torch.no_grad()
def __init__(self,
*args,
observe_other_suppressant: bool = False,
observe_other_power: bool = False,
show_bad_actions: bool = False,
**kwargs) -> None:
"""
Initialize the Wildfire environment.
Args:
            observe_other_suppressant: bool - whether to observe the suppressant of other agents
observe_other_power: bool - whether to observe the power of other agents
            show_bad_actions: bool - whether to include invalid actions in agent action spaces (invalid selections are penalized) rather than filtering them out
"""
super().__init__(*args, **kwargs)
self.observe_other_suppressant = observe_other_suppressant
self.observe_other_power = observe_other_power
self.show_bad_actions = show_bad_actions
self.possible_agents = tuple(f"firefighter_{i}" for i in range(1, self.agent_config.num_agents + 1))
self.agent_name_mapping = dict(zip(self.possible_agents, torch.arange(0, len(self.possible_agents), device=self.device)))
self.agent_position_mapping = dict(zip(self.possible_agents, self.agent_config.agents))
self.ignition_temp = self.fire_config.ignition_temp
self.max_x = self.config.grid_width
self.max_y = self.config.grid_height
# Set the transition filter for the fire spread
self.fire_spread_weights = self.config.fire_spread_weights.to(self.device)
# Pre-create range indices tensors for use in later operations
self.fire_index_ranges = torch.arange(self.parallel_envs * self.max_y * self.max_x, device=self.device)
self.parallel_ranges = torch.arange(self.parallel_envs, device=self.device)
# Create the agent mapping for observation ordering
agent_ids = torch.arange(0, self.agent_config.num_agents, device=self.device)
self.observation_ordering = {}
for agent in self.possible_agents:
agent_idx = self.agent_name_mapping[agent]
other_agents = agent_ids[agent_ids != agent_idx]
self.observation_ordering[agent] = other_agents
self.agent_observation_bounds = tuple([
self.max_y,
self.max_x,
self.agent_config.max_fire_reduction_power,
self.agent_config.suppressant_states,
])
self.fire_observation_bounds = tuple([
self.max_y,
self.max_x,
self.fire_config.max_fire_type,
self.fire_config.num_fire_states,
])
observation_mask = torch.ones(4, dtype=torch.bool, device=self.device)
observation_mask[2] = self.observe_other_power
observation_mask[3] = self.observe_other_suppressant
self.agent_observation_mask = lambda agent_name: observation_mask
# Initialize all of the transition layers based on the environment configurations
self.capacity_transition = transitions.CapacityTransition(
agent_shape=(self.parallel_envs, self.agent_config.num_agents),
stochastic_switch=self.stochastic_config.tank_switch,
tank_switch_probability=self.agent_config.tank_switch_probability,
possible_capacities=self.agent_config.possible_capacities,
capacity_probabilities=self.agent_config.capacity_probabilities,
).to(self.device)
self.equipment_transition = transitions.EquipmentTransition(
equipment_states=self.agent_config.equipment_states,
stochastic_repair=self.stochastic_config.repair,
repair_probability=self.agent_config.repair_probability,
stochastic_degrade=self.stochastic_config.degrade,
degrade_probability=self.agent_config.degrade_probability,
critical_error=self.stochastic_config.critical_error,
critical_error_probability=self.agent_config.critical_error_probability,
).to(self.device)
self.fire_decrease_transition = transitions.FireDecreaseTransition(
fire_shape=(self.parallel_envs, self.max_y, self.max_x),
stochastic_decrease=self.stochastic_config.fire_decrease,
decrease_probability=self.fire_config.intensity_decrease_probability,
extra_power_decrease_bonus=self.fire_config.extra_power_decrease_bonus,
).to(self.device)
self.fire_increase_transition = transitions.FireIncreaseTransition(
fire_shape=(self.parallel_envs, self.max_y, self.max_x),
fire_states=self.fire_config.num_fire_states,
stochastic_increase=self.stochastic_config.fire_increase,
intensity_increase_probability=self.fire_config.intensity_increase_probability,
stochastic_burnouts=self.stochastic_config.special_burnout_probability,
burnout_probability=self.fire_config.burnout_probability,
).to(self.device)
self.fire_spread_transition = transitions.FireSpreadTransition(
fire_spread_weights=self.fire_spread_weights,
ignition_temperatures=self.ignition_temp,
use_fire_fuel=self.stochastic_config.fire_fuel,
).to(self.device)
self.suppressant_decrease_transition = transitions.SuppressantDecreaseTransition(
agent_shape=(self.parallel_envs, self.agent_config.num_agents),
stochastic_decrease=self.stochastic_config.suppressant_decrease,
decrease_probability=self.agent_config.suppressant_decrease_probability,
).to(self.device)
self.suppressant_refill_transition = transitions.SuppressantRefillTransition(
agent_shape=(self.parallel_envs, self.agent_config.num_agents),
stochastic_refill=self.stochastic_config.suppressant_refill,
refill_probability=self.agent_config.suppressant_refill_probability,
equipment_bonuses=self.agent_config.equipment_states[:, 0],
).to(self.device)
@torch.no_grad()
def reset(self, seed: Union[List[int], int] = None, options: Dict[str, Any] = None) -> None:
"""
Reset the environment.
Args:
seed: Union[List[int], int] - the seed to use
options: Dict[str, Any] - the options for the reset
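        Example:
            A sketch of the reset options handled below ('initial_state', 'skip_observations',
            'skip_actions'); `saved_state` is a placeholder for a previously captured WildfireState
            whose batch dimension matches the number of parallel environments:
            >>> env.reset(options={'initial_state': saved_state, 'skip_observations': False, 'skip_actions': False})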
"""
super().reset(seed=seed, options=options)
# Initialize the agent action to task mapping
self.agent_action_to_task_mapping = {agent: {} for agent in self.agents}
self.fire_reduction_power = self.agent_config.fire_reduction_power
self.suppressant_states = self.agent_config.suppressant_states
# Initialize the state
self._state = WildfireState(
fires=torch.zeros(
(self.parallel_envs, self.max_y, self.max_x),
dtype=torch.int32,
device=self.device,
),
intensity=torch.zeros(
(self.parallel_envs, self.max_y, self.max_x),
dtype=torch.int32,
device=self.device,
),
fuel=torch.zeros(
(self.parallel_envs, self.max_y, self.max_x),
dtype=torch.int32,
device=self.device,
),
agents=self.agent_config.agents,
capacity=torch.ones(
(self.parallel_envs, self.agent_config.num_agents),
dtype=torch.float32,
device=self.device,
),
suppressants=torch.ones(
(self.parallel_envs, self.agent_config.num_agents),
dtype=torch.float32,
device=self.device,
),
equipment=torch.ones(
(self.parallel_envs, self.agent_config.num_agents),
dtype=torch.int32,
device=self.device,
),
)
if options is not None and options.get('initial_state') is not None:
initial_state = options['initial_state']
if len(initial_state) != self.parallel_envs:
raise ValueError("Initial state must have the same number of environments as the parallel environments")
self._state = initial_state
else:
self._state.fires[:, self.fire_config.lit] = self.fire_config.fire_types[self.fire_config.lit]
self._state.fires[:, ~self.fire_config.lit] = -1 * self.fire_config.fire_types[~self.fire_config.lit]
self._state.intensity[:, self.fire_config.lit] = self.ignition_temp[self.fire_config.lit]
self._state.fuel[self._state.fires != 0] = self.fire_config.initial_fuel
self._state.suppressants[:, :] = self.agent_config.initial_suppressant
self._state.capacity[:, :] = self.agent_config.initial_capacity
self._state.equipment[:, :] = self.agent_config.initial_equipment_state
self._state.save_initial()
# Initialize the rewards for all environments
self.fire_rewards = self.reward_config.fire_rewards.unsqueeze(0).expand(self.parallel_envs, -1, -1)
self.num_burnouts = torch.zeros(self.parallel_envs, dtype=torch.int32, device=self.device)
        # Initialize the mapping of the tasks "in" the environment, used to map actions
self.agent_action_mapping = {agent: None for agent in self.agents}
self.agent_observation_mapping = {agent: None for agent in self.agents}
self.agent_bad_actions = {agent: None for agent in self.agents}
# Set the observations and action space
if not options or not options.get('skip_observations', False):
self.update_observations()
if not options or not options.get('skip_actions', False):
self.update_actions()
self._post_reset_hook()
@torch.no_grad()
def reset_batches(self,
batch_indices: torch.Tensor,
seed: Optional[List[int]] = None,
options: Optional[Dict[str, Any]] = None) -> None:
"""
Partial reset of the environment for the given batch indices.
Args:
batch_indices: torch.Tensor - the batch indices to reset
seed: Optional[List[int]] - the seed to use
options: Optional[Dict[str, Any]] - the options for the reset
"""
super().reset_batches(batch_indices, seed, options)
# Reset the state
self._state.restore_initial(batch_indices)
self.num_burnouts[batch_indices] = 0
# Reset the observation updates
self.update_observations()
self.update_actions()
@torch.no_grad()
def step_environment(self) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], Dict[str, Dict[str, bool]]]:
"""
        Run the simultaneous-action wildfire environment step.
        Returns:
            Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], Dict[str, Dict[str, bool]]] - rewards, terminations, and infos for each agent
"""
# Initialize storages
rewards = {agent: torch.zeros(self.parallel_envs, dtype=torch.float32, device=self.device) for agent in self.agents}
terminations = {agent: torch.zeros(self.parallel_envs, dtype=torch.bool, device=self.device) for agent in self.agents}
infos = {agent: {} for agent in self.agents}
# For simplification purposes, one randomness generation is done per step, then taken piecewise
field_randomness = self.generator.generate(self.parallel_envs, 3, (self.max_y, self.max_x), key='field')
agent_randomness = self.generator.generate(self.parallel_envs, 5, (self.agent_config.num_agents, ), key='agent')
shape = (self.agent_config.num_agents, self.parallel_envs)
refills = torch.zeros(shape, dtype=torch.bool, device=self.device)
users = torch.zeros(shape, dtype=torch.bool, device=self.device)
attack_powers = torch.zeros_like(self._state.fires, dtype=torch.float32, device=self.device)
fire_positions = (self._state.fires > 0).nonzero(as_tuple=False)
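        # Prefix-sum offsets used to convert each environment's local task indices into global indices into fire_positions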
index_shifts = torch.roll(torch.cumsum(self.environment_task_count, dim=0), 1, dims=0)
index_shifts[0] = 0
index_shifts = index_shifts[self.parallel_ranges.unsqueeze(1)]
# Loop over each agent
for agent_name, agent_actions in self.actions.items():
agent_index = self.agent_name_mapping[agent_name]
# Determine in which environments this agent is refilling suppressant
refills[agent_index] = agent_actions[:, 1] == -1
if self.agent_task_count[agent_index].sum() == 0:
continue
agent_action_mapping_pad = self.agent_action_mapping[agent_name].to_padded_tensor(padding=-100)
# Gather the indexes of all of the agent actions
task_indices = torch.hstack([self.parallel_ranges.unsqueeze(1), agent_actions[:, 0].unsqueeze(1)])
task_indices = agent_action_mapping_pad[task_indices[~refills[agent_index]].split(1, dim=1)]
global_task_indices = (agent_actions[:, 0] + index_shifts.squeeze(1))[~refills[agent_index]]
fire_coords = fire_positions[global_task_indices]
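            # Effective attack power = the agent's base fire reduction power plus its current equipment bonus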
reduction_powers = self.fire_reduction_power[agent_index].expand(self.parallel_envs)
equipment_bonuses = self.agent_config.equipment_states[self._state.equipment[:, agent_index].unsqueeze(0)][:, :, 1]
full_powers = (reduction_powers + equipment_bonuses).squeeze(0)
# Create a fight tensor that we will update to filter bad actions
good_fight = torch.ones(self.parallel_envs, dtype=torch.bool, device=self.device)
good_fight[refills[agent_index]] = False
if self.show_bad_actions and self.agent_bad_actions[agent_name].numel() > 0:
bad_actions = self.agent_bad_actions[agent_name].to_padded_tensor(-100)
attacks = agent_actions[:, 0].unsqueeze(1).expand_as(bad_actions)
bad_actions = (bad_actions == attacks).any(dim=1)
good_fight = good_fight & (~bad_actions & ~refills[agent_index])
attack_powers[fire_coords[good_fight[~refills[agent_index]]].split(1, dim=1)] += full_powers[good_fight].unsqueeze(1)
# Aggregate the filtered information
users[agent_index] = good_fight
bad_users = ~(users[agent_index] | refills[agent_index])
# Give out rewards for bad actions
rewards[agent_name][bad_users] = self.reward_config.bad_attack_penalty
refills = refills.T
users = users.T
# Handle agent suppressant decrease
self._state = self.suppressant_decrease_transition(
state=self._state,
used_suppressants=users,
randomness_source=agent_randomness[0],
)
# Handle agent equipment transitions
self._state = self.equipment_transition(
state=self._state,
randomness_source=agent_randomness[1],
)
# Handle agent suppressant transitions
self._state, who_increased_suppressants = self.suppressant_refill_transition(
state=self._state,
refilled_suppressants=refills,
randomness_source=agent_randomness[2],
return_increased=True,
)
# Handle the agent capacity transitions
self._state = self.capacity_transition(
state=self._state,
targets=who_increased_suppressants,
randomness_source=agent_randomness[3:5],
)
# Handle fire intensity transitions
self._state, just_burned_out = self.fire_increase_transition(
state=self._state,
attack_counts=attack_powers,
randomness_source=field_randomness[0],
return_burned_out=True,
)
self._state, just_put_out = self.fire_decrease_transition(
state=self._state,
attack_counts=attack_powers,
randomness_source=field_randomness[1],
return_put_out=True,
)
self._state = self.fire_spread_transition(
state=self._state,
randomness_source=field_randomness[2],
)
fire_rewards = torch.zeros_like(self._state.fires, device=self.device, dtype=torch.float)
fire_rewards[just_put_out] = self.fire_rewards[just_put_out]
fire_rewards[just_burned_out] = self.reward_config.burnout_penalty
fire_rewards_per_batch = fire_rewards.sum(dim=(1, 2))
self.num_burnouts += just_burned_out.int().sum(dim=(1, 2))
# Assign rewards
for agent in self.agents:
rewards[agent] += fire_rewards_per_batch
# Determine environment terminations due to no more fires
fires_are_out = self._state.fires.flatten(start_dim=1).max(dim=1)[0] <= 0
if self.stochastic_config.fire_fuel:
# If all fires are out and all fuel is depleted the episode is over
depleted_fuel = self._state.fuel.sum(dim=(1, 2)) <= 0
batch_is_dead = torch.logical_and(depleted_fuel, fires_are_out)
else:
# If all fires are out then the episode is over
batch_is_dead = fires_are_out
newly_terminated = torch.logical_xor(self.terminated, batch_is_dead)
termination_reward = self.reward_config.termination_reward / (self.num_burnouts + 1)
for agent in self.agents:
rewards[agent][newly_terminated] += termination_reward[newly_terminated]
terminations[agent] = batch_is_dead
return rewards, terminations, infos
@torch.no_grad()
def update_actions(self) -> None:
"""Update the action space for all agents."""
# Gather all the tasks in the environment
fires = self._state.fires > 0
num_tasks_per_environment = fires.sum(dim=(1, 2))
num_tasks = num_tasks_per_environment.sum()
# Gather the positions of all of the fires
fire_positions = fires.nonzero(as_tuple=False)
fire_positions_expanded = fire_positions.expand(self.agent_config.num_agents, -1, -1)
fire_positions_expanded = fire_positions_expanded.flatten(end_dim=1)
# Match up all of the agents with all of the tasks
agent_indices = torch.arange(0, self.agent_config.num_agents, device=self.device).unsqueeze(1)
agent_indices_expanded = agent_indices.expand(-1, num_tasks).flatten().unsqueeze(1)
agent_indices_expanded = torch.cat([fire_positions_expanded[:, 0].unsqueeze(1), agent_indices_expanded], dim=1)
agent_positions = self._state.agents.unsqueeze(1).expand(-1, num_tasks, -1).flatten(end_dim=1)
# Calculate the true range for each of the agents in each environment
agent_ranges = self.agent_config.attack_range[agent_indices_expanded[:, 1]].flatten()
equipment_states = self._state.equipment[agent_indices_expanded.split(1, dim=1)].squeeze(1)
range_bonuses = self.agent_config.equipment_states[equipment_states.unsqueeze(0)][:, :, 2].squeeze(0)
true_range = (agent_ranges + range_bonuses).flatten()
# Check which agents are in range of which tasks
in_range = in_range_check.chebyshev(
agent_position=agent_positions,
task_position=fire_positions_expanded[:, 1:],
attack_range=true_range,
).reshape(self.agent_config.num_agents, num_tasks)
# Check which agents have suppressants
has_suppressants = self._state.suppressants[agent_indices_expanded.split(1, dim=1)].squeeze(1) > 0
has_suppressants = has_suppressants.reshape(self.agent_config.num_agents, num_tasks)
# Combine the two checks to determine which agents can fight which fires
checks = torch.logical_and(in_range, has_suppressants)
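        # Convert global task indices back into per-environment local indices using prefix-sum offsets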
index_shifts = torch.roll(torch.cumsum(num_tasks_per_environment, dim=0), 1, 0)
index_shifts[0] = 0
task_range = torch.arange(0, num_tasks, device=self.device)
task_indices = (task_range - index_shifts[fire_positions[:, 0]].flatten())
task_indices_nested = torch.nested.as_nested_tensor(
task_indices.split(num_tasks_per_environment.tolist(), dim=0),
device=self.device,
)
# Aggregate the indices of all tasks to agent mapping
for agent, agent_number in self.agent_name_mapping.items():
# Count the number of tasks in each batch and aggregate the indices
task_count = torch.bincount(fire_positions[checks[agent_number]][:, 0], minlength=self.parallel_envs)
bad_task_count = num_tasks_per_environment - task_count
tasks = task_indices[checks[agent_number]]
batchwise_indices = torch.nested.as_nested_tensor(tasks.split(task_count.tolist(), dim=0), device=self.device)
bad_tasks = task_indices[~checks[agent_number]]
bad_batchwise_indices = torch.nested.as_nested_tensor(
bad_tasks.split(bad_task_count.tolist(), dim=0),
device=self.device,
)
self.agent_task_count[agent_number] = task_count
if self.show_bad_actions:
self.agent_action_mapping[agent] = task_indices_nested
self.agent_bad_actions[agent] = bad_batchwise_indices
else:
self.agent_action_mapping[agent] = batchwise_indices
self.agent_observation_mapping[agent] = task_indices_nested
self.environment_task_count = num_tasks_per_environment
@torch.no_grad()
def update_observations(self) -> None:
"""
Update the observations for the agents.
Observations consist of the following:
- Agent observation format: (batch, 1, (y, x, power, suppressant))
- Others observation format: (batch, agents - 1, (y, x, power, suppressant))
- Fire observation format: (batch, fires, (y, x, heat, intensity))
"""
# Build the agent observations
agent_positions = self._state.agents.expand(self.parallel_envs, -1, -1)
fire_reduction_power = self.fire_reduction_power.unsqueeze(-1).expand(self.parallel_envs, -1, -1)
suppressants = self._state.suppressants.unsqueeze(-1)
agent_observations = torch.cat((agent_positions, fire_reduction_power, suppressants), dim=2)
# Build the fire observations
lit_fires = torch.where(self._state.fires > 0, 1, 0)
lit_fire_indices = lit_fires.nonzero(as_tuple=False)
intensities = self._state.intensity[lit_fire_indices.split(1, dim=1)]
fires = self._state.fires[lit_fire_indices.split(1, dim=1)]
task_count = lit_fires.sum(dim=(1, 2))
fire_observations = torch.cat([lit_fire_indices[:, 1:], fires, intensities], dim=1)
fire_observations = torch.nested.as_nested_tensor(fire_observations.split(task_count.tolist(), dim=0))
self.task_store = fire_observations
# Aggregate the full observation space
self.observations = {}
for agent in self.agents:
observation_mask = self.agent_observation_mask(agent)
agent_index = self.agent_name_mapping[agent]
agent_mask = torch.ones(self.agent_config.num_agents, dtype=torch.bool, device=self.device)
agent_mask[agent_index] = False
self.observations[agent] = TensorDict(
{
'self': agent_observations[:, agent_index],
'others': agent_observations[:, agent_mask][:, :, observation_mask],
'tasks': fire_observations
},
batch_size=[self.parallel_envs],
device=self.device,
)
@torch.no_grad()
def action_space(self, agent: str) -> List[gymnasium.Space]:
"""
Return the action space for the given agent.
Args:
agent: str - the name of the agent to retrieve the action space for
Returns:
List[gymnasium.Space]: the action space for the given agent
"""
if self.show_bad_actions:
num_tasks_in_environment = self.environment_task_count
else:
num_tasks_in_environment = self.agent_task_count[self.agent_name_mapping[agent]]
return actions.build_action_space(environment_task_counts=num_tasks_in_environment)
@torch.no_grad()
def observation_space(self, agent: str) -> List[gymnasium.Space]:
"""
Return the observation space for the given agent.
Args:
agent: str - the name of the agent to retrieve the observation space for
Returns:
List[gymnasium.Space]: the observation space for the given agent
"""
return observations.build_observation_space(
environment_task_counts=self.environment_task_count,
num_agents=self.agent_config.num_agents,
agent_high=self.agent_observation_bounds,
fire_high=self.fire_observation_bounds,
include_suppressant=self.observe_other_suppressant,
include_power=self.observe_other_power,
)