Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature(pu): add pistonball_env, its unittest and qmix config #833

Merged
merged 17 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 70 additions & 10 deletions ding/model/template/qmix.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from typing import Union, List
from functools import reduce
from typing import List, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import reduce
from ding.utils import list_split, MODEL_REGISTRY
from ding.torch_utils import fc_block, MLP
from ding.torch_utils import MLP, fc_block
from ding.utils import MODEL_REGISTRY, list_split

from ..common import ConvEncoder
from .q_learning import DRQN


Expand Down Expand Up @@ -111,7 +114,7 @@ def __init__(
self,
agent_num: int,
obs_shape: int,
global_obs_shape: int,
global_obs_shape: Union[int, List[int]],
action_shape: int,
hidden_size_list: list,
mixer: bool = True,
Expand Down Expand Up @@ -146,8 +149,34 @@ def __init__(
embedding_size = hidden_size_list[-1]
self.mixer = mixer
if self.mixer:
self._mixer = Mixer(agent_num, global_obs_shape, embedding_size, activation=activation)
self._global_state_encoder = nn.Identity()
global_obs_shape_type = self._get_global_obs_shape_type(global_obs_shape)

if global_obs_shape_type == "flat":
self._mixer = Mixer(agent_num, global_obs_shape, embedding_size, activation=activation)
self._global_state_encoder = nn.Identity()
elif global_obs_shape_type == "image":
self._mixer = Mixer(agent_num, embedding_size, embedding_size, activation=activation)
self._global_state_encoder = ConvEncoder(
global_obs_shape, hidden_size_list=hidden_size_list, activation=activation, norm_type='BN'
)
else:
raise ValueError(f"Unsupported global_obs_shape: {global_obs_shape}")

def _get_global_obs_shape_type(self, global_obs_shape: Union[int, List[int]]) -> str:
"""
Overview:
Determine the type of global observation shape.
Arguments:
- global_obs_shape (:obj:`Union[int, List[int]]`): The global observation state.
Returns:
- obs_shape_type (:obj:`str`): 'flat' for 1D observation or 'image' for 3D observation.
"""
if isinstance(global_obs_shape, int) or (isinstance(global_obs_shape, list) and len(global_obs_shape) == 1):
return "flat"
elif isinstance(global_obs_shape, list) and len(global_obs_shape) == 3:
return "image"
else:
raise ValueError(f"Unsupported global_obs_shape: {global_obs_shape}")

def forward(self, data: dict, single_step: bool = True) -> dict:
"""
Expand Down Expand Up @@ -182,8 +211,16 @@ def forward(self, data: dict, single_step: bool = True) -> dict:
agent_state, global_state, prev_state = data['obs']['agent_state'], data['obs']['global_state'], data[
'prev_state']
action = data.get('action', None)
# If single_step is True, add a new dimension at the front of agent_state
# This is necessary to maintain the expected input shape for the model,
# which requires a time step dimension even when processing a single step.
if single_step:
agent_state, global_state = agent_state.unsqueeze(0), global_state.unsqueeze(0)
agent_state = agent_state.unsqueeze(0)
# If single_step is True and global_state has 2 dimensions, add a new dimension at the front of global_state
# This ensures that global_state has the same number of dimensions as agent_state,
# allowing for consistent processing in the forward computation.
if single_step and len(global_state.shape) == 2:
global_state = global_state.unsqueeze(0)
puyuan1996 marked this conversation as resolved.
Show resolved Hide resolved
T, B, A = agent_state.shape[:3]
assert len(prev_state) == B and all(
[len(p) == A for p in prev_state]
Expand All @@ -205,15 +242,38 @@ def forward(self, data: dict, single_step: bool = True) -> dict:
agent_q_act = torch.gather(agent_q, dim=-1, index=action.unsqueeze(-1))
agent_q_act = agent_q_act.squeeze(-1) # T, B, A
if self.mixer:
global_state_embedding = self._global_state_encoder(global_state)
global_state_embedding = self._process_global_state(global_state)
total_q = self._mixer(agent_q_act, global_state_embedding)
else:
total_q = agent_q_act.sum(-1)
total_q = agent_q_act.sum(dim=-1)

if single_step:
total_q, agent_q = total_q.squeeze(0), agent_q.squeeze(0)

return {
'total_q': total_q,
'logit': agent_q,
'next_state': next_state,
'action_mask': data['obs']['action_mask']
}

def _process_global_state(self, global_state: torch.Tensor) -> torch.Tensor:
"""
puyuan1996 marked this conversation as resolved.
Show resolved Hide resolved
Overview:
Process the global state to obtain an embedding.
Arguments:
- global_state (:obj:`torch.Tensor`): The global state tensor.

Returns:
- global_state_embedding (:obj:`torch.Tensor`): The processed global state embedding.
"""
# If global_state has 5 dimensions, it's likely in the form [batch_size, time_steps, C, H, W]
if global_state.dim() == 5:
# Reshape and apply the global state encoder
batch_time_shape = global_state.shape[:2] # [batch_size, time_steps]
reshaped_state = global_state.view(-1, *global_state.shape[-3:]) # Collapse batch and time dims
encoded_state = self._global_state_encoder(reshaped_state)
return encoded_state.view(*batch_time_shape, -1) # Reshape back to [batch_size, time_steps, embedding_dim]
else:
# For lower-dimensional states, apply the encoder directly
return self._global_state_encoder(global_state)
31 changes: 31 additions & 0 deletions ding/model/template/tests/test_qmix.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,34 @@ def test_qmix():
is_differentiable(loss, qmix_model)
data.pop('action')
output = qmix_model(data, single_step=False)


@pytest.mark.unittest
def test_qmix_process_global_state():
# Test the behavior of the _process_global_state method with different global_obs_shape types
agent_num, obs_dim, global_obs_dim, action_dim = 4, 32, 32 * 4, 9
embedding_dim = 64

# Case 1: Test "flat" type global_obs_shape
global_obs_shape = global_obs_dim # Flat global_obs_shape
qmix_model_flat = QMix(agent_num, obs_dim, global_obs_shape, action_dim, [64, 128, embedding_dim], mixer=True)

# Simulate input for the "flat" type global_state
batch_size, time_steps = 3, 8
global_state_flat = torch.randn(batch_size, time_steps, global_obs_dim)
processed_flat = qmix_model_flat._process_global_state(global_state_flat)

# Ensure the output shape is correct [batch_size, time_steps, embedding_dim]
assert processed_flat.shape == (batch_size, time_steps, global_obs_dim)

# Case 2: Test "image" type global_obs_shape
global_obs_shape = [3, 64, 64] # Image-shaped global_obs_shape (C, H, W)
qmix_model_image = QMix(agent_num, obs_dim, global_obs_shape, action_dim, [64, 128, embedding_dim], mixer=True)

# Simulate input for the "image" type global_state
C, H, W = global_obs_shape
global_state_image = torch.randn(batch_size, time_steps, C, H, W)
processed_image = qmix_model_image._process_global_state(global_state_image)

# Ensure the output shape is correct [batch_size, time_steps, embedding_dim]
assert processed_image.shape == (batch_size, time_steps, embedding_dim)
79 changes: 79 additions & 0 deletions dizoo/petting_zoo/config/ptz_pistonball_qmix_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from easydict import EasyDict

n_pistons = 20
collector_env_num = 8
evaluator_env_num = 8
max_env_step = 3e6

main_config = dict(
exp_name=f'data_pistonball/ptz_pistonball_n{n_pistons}_qmix_seed0',
env=dict(
env_family='butterfly',
env_id='pistonball_v6',
n_pistons=n_pistons,
max_cycles=125,
agent_obs_only=False,
continuous_actions=False,
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
n_evaluator_episode=evaluator_env_num,
stop_value=1e6,
manager=dict(shared_memory=False,),
),
policy=dict(
cuda=True,
model=dict(
agent_num=n_pistons,
obs_shape=(3, 457, 120), # RGB image observation shape for each piston agent
global_obs_shape=(3, 560, 880), # Global state shape
action_shape=3, # Discrete actions (0, 1, 2)
hidden_size_list=[32, 64, 128, 256],
mixer=True,
),
learn=dict(
update_per_collect=20,
batch_size=32,
learning_rate=0.0001,
clip_value=5,
target_update_theta=0.001,
discount_factor=0.99,
double_q=True,
),
collect=dict(
n_sample=16,
unroll_len=5,
env_num=collector_env_num,
),
eval=dict(env_num=evaluator_env_num),
other=dict(
eps=dict(
type='exp',
start=1.0,
end=0.05,
decay=100000,
),
replay_buffer=dict(
replay_buffer_size=5000,
),
),
),
)
main_config = EasyDict(main_config)

create_config = dict(
env=dict(
import_names=['dizoo.petting_zoo.envs.petting_zoo_pistonball_env'],
type='petting_zoo_pistonball',
),
env_manager=dict(type='subprocess'),
policy=dict(type='qmix'),
)
create_config = EasyDict(create_config)

ptz_pistonball_qmix_config = main_config
ptz_pistonball_qmix_create_config = create_config

if __name__ == '__main__':
# or you can enter `ding -m serial -c ptz_pistonball_qmix_config.py -s 0`
from ding.entry import serial_pipeline
serial_pipeline((main_config, create_config), seed=0, max_env_step=max_env_step)
Loading
Loading