fix(pu): fix last_game_segment bug in muzero_segment_collector.py
dyyoungg committed Sep 13, 2024
1 parent 51e10f2 commit 8615899
Showing 7 changed files with 1,636 additions and 96 deletions.
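
The core of this commit is the last file below: in MuZeroSegmentCollector, per-environment bookkeeping (dones, last_game_segments, last_game_priorities, and the action_mask/to_play/chance dicts) is promoted from local variables inside collect() to instance attributes initialized in reset(), so a partially filled segment carried over from a previous collect() call is no longer lost. A minimal, hypothetical sketch of that pattern (illustrative names only, not the actual LZero classes):

from typing import Any, List, Optional

import numpy as np


class SegmentCollectorSketch:
    """Sketch of a collector that keeps per-env state across collect() calls."""

    def __init__(self, env_num: int) -> None:
        self._env_num = env_num
        self.reset()

    def reset(self) -> None:
        # Persistent per-environment state, mirroring the attributes this
        # commit adds to MuZeroSegmentCollector.reset().
        self.dones = np.array([False for _ in range(self._env_num)])
        self.last_game_segments: List[Optional[Any]] = [None] * self._env_num
        self.last_game_priorities: List[Optional[Any]] = [None] * self._env_num

    def on_segment_full(self, env_id: int, segment: Any, priorities: Any) -> None:
        # When an env's current segment fills up, remember it so the *next*
        # collect() call can still pad and save it; as a local variable it
        # would be re-created each call and the carried-over segment dropped.
        self.last_game_segments[env_id] = segment
        self.last_game_priorities[env_id] = priorities
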
2 changes: 1 addition & 1 deletion lzero/entry/train_unizero.py
@@ -18,7 +18,7 @@
from lzero.policy import visit_count_temperature
from lzero.policy.random_policy import LightZeroRandomPolicy
# from lzero.worker import MuZeroCollector as Collector
from lzero.worker import MuZeroSegmentCollector as Collector # TODO
from lzero.worker import MuZeroSegmentCollector as Collector # ============ TODO: ============
from lzero.worker import MuZeroEvaluator as Evaluator
from .utils import random_collect

2 changes: 2 additions & 0 deletions lzero/model/unizero_world_models/utils.py
@@ -301,6 +301,8 @@ def __init__(self, latent_recon_loss_weight=0, perceptual_loss_weight=0, **kwarg

# Define the weights for each loss type
self.obs_loss_weight = 10
# self.obs_loss_weight = 1 # for use_aug

self.reward_loss_weight = 1.
self.value_loss_weight = 0.25
self.policy_loss_weight = 1.
133 changes: 51 additions & 82 deletions lzero/worker/muzero_segment_collector.py
@@ -143,6 +143,16 @@ def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvMana

self._env_info = {env_id: {'time': 0., 'step': 0} for env_id in range(self._env_num)}

# Initialize action_mask_dict, to_play_dict and chance_dict here to ensure they contain entries for every env_id.
self.action_mask_dict = {i: None for i in range(self._env_num)}
self.to_play_dict = {i: None for i in range(self._env_num)}
if self.policy_config.use_ture_chance_label_in_chance_encoder:
self.chance_dict = {i: None for i in range(self._env_num)}

self.dones = np.array([False for _ in range(self._env_num)])
self.last_game_segments = [None for _ in range(self._env_num)]
self.last_game_priorities = [None for _ in range(self._env_num)]

self._episode_info = []
self._total_envstep_count = 0
self._total_episode_count = 0
@@ -356,10 +366,18 @@ def collect(self,
)
init_obs = self._env.ready_obs

action_mask_dict = {i: to_ndarray(init_obs[i]['action_mask']) for i in range(env_nums)}
to_play_dict = {i: to_ndarray(init_obs[i]['to_play']) for i in range(env_nums)}
if self.policy_config.use_ture_chance_label_in_chance_encoder:
chance_dict = {i: to_ndarray(init_obs[i]['chance']) for i in range(env_nums)}
# action_mask_dict = {i: to_ndarray(init_obs[i]['action_mask']) for i in range(env_nums)}
# to_play_dict = {i: to_ndarray(init_obs[i]['to_play']) for i in range(env_nums)}
# if self.policy_config.use_ture_chance_label_in_chance_encoder:
# chance_dict = {i: to_ndarray(init_obs[i]['chance']) for i in range(env_nums)}

# Use the persistent self.action_mask_dict (and related) attributes directly instead.
for env_id in range(env_nums):
if env_id in init_obs.keys():
self.action_mask_dict[env_id] = to_ndarray(init_obs[env_id]['action_mask'])
self.to_play_dict[env_id] = to_ndarray(init_obs[env_id]['to_play'])
if self.policy_config.use_ture_chance_label_in_chance_encoder:
self.chance_dict[env_id] = to_ndarray(init_obs[env_id]['chance'])

game_segments = [
GameSegment(
@@ -378,9 +396,10 @@

game_segments[env_id].reset(observation_window_stack[env_id])

dones = np.array([False for _ in range(env_nums)])
last_game_segments = [None for _ in range(env_nums)]
last_game_priorities = [None for _ in range(env_nums)]
# dones = np.array([False for _ in range(env_nums)])
# last_game_segments = [None for _ in range(env_nums)]
# last_game_priorities = [None for _ in range(env_nums)]

# for priorities in self-play
search_values_lst = [[] for _ in range(env_nums)]
pred_values_lst = [[] for _ in range(env_nums)]
@@ -397,8 +416,6 @@
self_play_visit_entropy = []
total_transitions = 0

ready_env_id = set()
remain_episode = n_episode
if collect_with_pure_policy:
temp_visit_list = [0.0 for i in range(self._env.action_space.n)]

@@ -410,19 +427,18 @@
with self._timer:
# Get current ready env obs.
obs = self._env.ready_obs
new_available_env_id = set(obs.keys()).difference(ready_env_id)
ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode]))
remain_episode -= min(len(new_available_env_id), remain_episode)
ready_env_id = set(obs.keys())

stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id}
stack_obs = list(stack_obs.values())

action_mask_dict = {env_id: action_mask_dict[env_id] for env_id in ready_env_id}
to_play_dict = {env_id: to_play_dict[env_id] for env_id in ready_env_id}
action_mask = [action_mask_dict[env_id] for env_id in ready_env_id]
to_play = [to_play_dict[env_id] for env_id in ready_env_id]
self.action_mask_dict_tmp = {env_id: self.action_mask_dict[env_id] for env_id in ready_env_id}
self.to_play_dict_tmp = {env_id: self.to_play_dict[env_id] for env_id in ready_env_id}

action_mask = [self.action_mask_dict_tmp[env_id] for env_id in ready_env_id]
to_play = [self.to_play_dict_tmp[env_id] for env_id in ready_env_id]
if self.policy_config.use_ture_chance_label_in_chance_encoder:
chance_dict = {env_id: chance_dict[env_id] for env_id in ready_env_id}
self.chance_dict_tmp = {env_id: self.chance_dict[env_id] for env_id in ready_env_id}

stack_obs = to_ndarray(stack_obs)
# return stack_obs shape: [B, S*C, W, H] e.g. [8, 4*1, 96, 96]
@@ -526,26 +542,26 @@ def collect(self,
# in ``game_segments[env_id].init``, we have appended o_{t} in ``self.obs_segment``
if self.policy_config.use_ture_chance_label_in_chance_encoder:
game_segments[env_id].append(
actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id],
to_play_dict[env_id], chance_dict[env_id]
actions[env_id], to_ndarray(obs['observation']), reward, self.action_mask_dict_tmp[env_id],
self.to_play_dict_tmp[env_id], self.chance_dict_tmp[env_id]
)
else:
game_segments[env_id].append(
actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id],
to_play_dict[env_id]
actions[env_id], to_ndarray(obs['observation']), reward, self.action_mask_dict_tmp[env_id],
self.to_play_dict_tmp[env_id]
)

# NOTE: the position of code snippet is very important.
# the obs['action_mask'] and obs['to_play'] are corresponding to the next action
action_mask_dict[env_id] = to_ndarray(obs['action_mask'])
to_play_dict[env_id] = to_ndarray(obs['to_play'])
self.action_mask_dict_tmp[env_id] = to_ndarray(obs['action_mask'])
self.to_play_dict_tmp[env_id] = to_ndarray(obs['to_play'])
if self.policy_config.use_ture_chance_label_in_chance_encoder:
chance_dict[env_id] = to_ndarray(obs['chance'])
self.chance_dict_tmp[env_id] = to_ndarray(obs['chance'])

if self.policy_config.ignore_done:
dones[env_id] = False
self.dones[env_id] = False
else:
dones[env_id] = done
self.dones[env_id] = done

if not collect_with_pure_policy:
visit_entropies_lst[env_id] += visit_entropy_dict[env_id]
@@ -575,10 +591,10 @@ def collect(self,
# if game segment is full, we will save the last game segment
if game_segments[env_id].is_full():
# pad over last segment trajectory
if last_game_segments[env_id] is not None:
if self.last_game_segments[env_id] is not None:
# TODO(pu): return the one game segment
self.pad_and_save_last_trajectory(
env_id, last_game_segments, last_game_priorities, game_segments, dones
env_id, self.last_game_segments, self.last_game_priorities, game_segments, self.dones
)

# calculate priority
@@ -589,8 +605,8 @@
improved_policy_lst[env_id] = []

# the current game_segments become last_game_segment
last_game_segments[env_id] = game_segments[env_id]
last_game_priorities[env_id] = priorities
self.last_game_segments[env_id] = game_segments[env_id]
self.last_game_priorities[env_id] = priorities

# create new GameSegment
game_segments[env_id] = GameSegment(
@@ -628,9 +644,9 @@

# NOTE: put the penultimate game segment in one episode into the trajectory_pool
# pad over 2th last game_segment using the last game_segment
if last_game_segments[env_id] is not None:
if self.last_game_segments[env_id] is not None:
self.pad_and_save_last_trajectory(
env_id, last_game_segments, last_game_priorities, game_segments, dones
env_id, self.last_game_segments, self.last_game_priorities, game_segments, self.dones
)

# store current segment trajectory
@@ -642,51 +658,8 @@
# assert len(game_segments[env_id]) == len(priorities)
# NOTE: save the last game segment in one episode into the trajectory_pool if it's not null
if len(game_segments[env_id].reward_segment) != 0:
self.game_segment_pool.append((game_segments[env_id], priorities, dones[env_id]))

# print(game_segments[env_id].reward_segment)
# reset the finished env and init game_segments
if n_episode > self._env_num:
# Get current ready env obs.
init_obs = self._env.ready_obs
retry_waiting_time = 0.001
while len(init_obs.keys()) != self._env_num:
# In order to be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to
# len(self._env.ready_obs), especially in tictactoe env.
self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys()))
self._logger.info('Before sleeping, the _env_states is {}'.format(self._env._env_states))
time.sleep(retry_waiting_time)
self._logger.info(
'=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' + '=' * 10
)
self._logger.info(
'After sleeping {}s, the current _env_states is {}'.format(
retry_waiting_time, self._env._env_states
)
)
init_obs = self._env.ready_obs

new_available_env_id = set(init_obs.keys()).difference(ready_env_id)
ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode]))
remain_episode -= min(len(new_available_env_id), remain_episode)

action_mask_dict[env_id] = to_ndarray(init_obs[env_id]['action_mask'])
to_play_dict[env_id] = to_ndarray(init_obs[env_id]['to_play'])
if self.policy_config.use_ture_chance_label_in_chance_encoder:
chance_dict[env_id] = to_ndarray(init_obs[env_id]['chance'])
self.game_segment_pool.append((game_segments[env_id], priorities, self.dones[env_id]))

game_segments[env_id] = GameSegment(
self._env.action_space,
game_segment_length=self.policy_config.game_segment_length,
config=self.policy_config
)
observation_window_stack[env_id] = deque(
[init_obs[env_id]['observation'] for _ in range(self.policy_config.model.frame_stack_num)],
maxlen=self.policy_config.model.frame_stack_num
)
game_segments[env_id].reset(observation_window_stack[env_id])
last_game_segments[env_id] = None
last_game_priorities[env_id] = None

# log
self_play_moves_max = max(self_play_moves_max, eps_steps_lst[env_id])
@@ -705,17 +678,13 @@
self._reset_stat(env_id)
ready_env_id.remove(env_id)

# if collected_episode >= n_episode:
# if self.collected_game_segments >= self._default_num_segments or one_episode_done: # game_segment_length = 400
# print(f'collect {self.collected_game_segments} segments now! one_episode_done: {one_episode_done}')
# self.collected_game_segments = 0

# In the v1 version, where this check sat inside the for loop, samples from some environments were likely lost.
# The v2 version below correctly returns the samples from all <env_num> environments.
if len(self.game_segment_pool) >= self._default_num_segments or one_episode_done: # game_segment_length = 400
# if len(self.game_segment_pool) >= self._default_num_segments or one_episode_done: # game_segment_length = 400
if len(self.game_segment_pool) >= self._default_num_segments: # game_segment_length = 400
print(f'collect {len(self.game_segment_pool)} segments now! one_episode_done: {one_episode_done}')
collected_enough_segments = True # Condition met, set the flag to True
# one_episode_done = False

# [data, meta_data]
return_data = [self.game_segment_pool[i][0] for i in range(len(self.game_segment_pool))], [
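For context on the v1/v2 comments above, a hedged one-function sketch (hypothetical names, not the actual collector code) of the v2 return condition: the collector keeps stepping every environment and only hands data back once the shared pool holds the requested number of segments, so no environment's samples are dropped mid-collection.

def should_return(game_segment_pool, default_num_segments: int) -> bool:
    # v2: decide based on the size of the shared pool, outside the per-env
    # loop, rather than returning per-episode inside the loop (v1), which
    # could drop samples from environments that had not finished yet.
    return len(game_segment_pool) >= default_num_segments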