diff --git a/lzero/entry/train_unizero.py b/lzero/entry/train_unizero.py index e0edbe67c..2b08da0ed 100644 --- a/lzero/entry/train_unizero.py +++ b/lzero/entry/train_unizero.py @@ -18,7 +18,7 @@ from lzero.policy import visit_count_temperature from lzero.policy.random_policy import LightZeroRandomPolicy # from lzero.worker import MuZeroCollector as Collector -from lzero.worker import MuZeroSegmentCollector as Collector # TODO +from lzero.worker import MuZeroSegmentCollector as Collector # ============ TODO: ============ from lzero.worker import MuZeroEvaluator as Evaluator from .utils import random_collect diff --git a/lzero/model/unizero_world_models/utils.py b/lzero/model/unizero_world_models/utils.py index d27d08214..7bec06a41 100644 --- a/lzero/model/unizero_world_models/utils.py +++ b/lzero/model/unizero_world_models/utils.py @@ -301,6 +301,8 @@ def __init__(self, latent_recon_loss_weight=0, perceptual_loss_weight=0, **kwarg # Define the weights for each loss type self.obs_loss_weight = 10 + # self.obs_loss_weight = 1 # for use_aug + self.reward_loss_weight = 1. self.value_loss_weight = 0.25 self.policy_loss_weight = 1. diff --git a/lzero/worker/muzero_segment_collector.py b/lzero/worker/muzero_segment_collector.py index b6821ef9b..901c959a7 100644 --- a/lzero/worker/muzero_segment_collector.py +++ b/lzero/worker/muzero_segment_collector.py @@ -143,6 +143,16 @@ def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvMana self._env_info = {env_id: {'time': 0., 'step': 0} for env_id in range(self._env_num)} + # Initialize action_mask_dict, to_play_dict and chance_dict here to make sure they contain values for every env_id + self.action_mask_dict = {i: None for i in range(self._env_num)} + self.to_play_dict = {i: None for i in range(self._env_num)} + if self.policy_config.use_ture_chance_label_in_chance_encoder: + self.chance_dict = {i: None for i in range(self._env_num)} + + self.dones = np.array([False for _ in range(self._env_num)]) + self.last_game_segments = [None for _ in range(self._env_num)] + self.last_game_priorities = [None for _ in range(self._env_num)] + + self._episode_info = [] self._total_envstep_count = 0 self._total_episode_count = 0 @@ -356,10 +366,18 @@ def collect(self, ) init_obs = self._env.ready_obs - action_mask_dict = {i: to_ndarray(init_obs[i]['action_mask']) for i in range(env_nums)} - to_play_dict = {i: to_ndarray(init_obs[i]['to_play']) for i in range(env_nums)} - if self.policy_config.use_ture_chance_label_in_chance_encoder: - chance_dict = {i: to_ndarray(init_obs[i]['chance']) for i in range(env_nums)} + # action_mask_dict = {i: to_ndarray(init_obs[i]['action_mask']) for i in range(env_nums)} + # to_play_dict = {i: to_ndarray(init_obs[i]['to_play']) for i in range(env_nums)} + # if self.policy_config.use_ture_chance_label_in_chance_encoder: + # chance_dict = {i: to_ndarray(init_obs[i]['chance']) for i in range(env_nums)} + + # Use the instance attributes (self.action_mask_dict, etc.) directly instead + for env_id in range(env_nums): + if env_id in init_obs.keys(): + self.action_mask_dict[env_id] = to_ndarray(init_obs[env_id]['action_mask']) + self.to_play_dict[env_id] = to_ndarray(init_obs[env_id]['to_play']) + if self.policy_config.use_ture_chance_label_in_chance_encoder: + self.chance_dict[env_id] = to_ndarray(init_obs[env_id]['chance']) game_segments = [ GameSegment( @@ -378,9 +396,10 @@ def collect(self, game_segments[env_id].reset(observation_window_stack[env_id]) - dones = np.array([False for _ in range(env_nums)]) - last_game_segments = [None for _ in range(env_nums)] - last_game_priorities = [None for _ in
range(env_nums)] + # dones = np.array([False for _ in range(env_nums)]) + # last_game_segments = [None for _ in range(env_nums)] + # last_game_priorities = [None for _ in range(env_nums)] + # for priorities in self-play search_values_lst = [[] for _ in range(env_nums)] pred_values_lst = [[] for _ in range(env_nums)] @@ -397,8 +416,6 @@ def collect(self, self_play_visit_entropy = [] total_transitions = 0 - ready_env_id = set() - remain_episode = n_episode if collect_with_pure_policy: temp_visit_list = [0.0 for i in range(self._env.action_space.n)] @@ -410,19 +427,18 @@ def collect(self, with self._timer: # Get current ready env obs. obs = self._env.ready_obs - new_available_env_id = set(obs.keys()).difference(ready_env_id) - ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) - remain_episode -= min(len(new_available_env_id), remain_episode) + ready_env_id = set(obs.keys()) stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id} stack_obs = list(stack_obs.values()) - action_mask_dict = {env_id: action_mask_dict[env_id] for env_id in ready_env_id} - to_play_dict = {env_id: to_play_dict[env_id] for env_id in ready_env_id} - action_mask = [action_mask_dict[env_id] for env_id in ready_env_id] - to_play = [to_play_dict[env_id] for env_id in ready_env_id] + self.action_mask_dict_tmp = {env_id: self.action_mask_dict[env_id] for env_id in ready_env_id} + self.to_play_dict_tmp = {env_id: self.to_play_dict[env_id] for env_id in ready_env_id} + + action_mask = [self.action_mask_dict_tmp[env_id] for env_id in ready_env_id] + to_play = [self.to_play_dict_tmp[env_id] for env_id in ready_env_id] if self.policy_config.use_ture_chance_label_in_chance_encoder: - chance_dict = {env_id: chance_dict[env_id] for env_id in ready_env_id} + self.chance_dict_tmp = {env_id: self.chance_dict[env_id] for env_id in ready_env_id} stack_obs = to_ndarray(stack_obs) # return stack_obs shape: [B, S*C, W, H] e.g. [8, 4*1, 96, 96] @@ -526,26 +542,26 @@ def collect(self, # in ``game_segments[env_id].init``, we have appended o_{t} in ``self.obs_segment`` if self.policy_config.use_ture_chance_label_in_chance_encoder: game_segments[env_id].append( - actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], - to_play_dict[env_id], chance_dict[env_id] + actions[env_id], to_ndarray(obs['observation']), reward, self.action_mask_dict_tmp[env_id], + self.to_play_dict_tmp[env_id], self.chance_dict_tmp[env_id] ) else: game_segments[env_id].append( - actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], - to_play_dict[env_id] + actions[env_id], to_ndarray(obs['observation']), reward, self.action_mask_dict_tmp[env_id], + self.to_play_dict_tmp[env_id] ) # NOTE: the position of code snippet is very important. 
# the obs['action_mask'] and obs['to_play'] are corresponding to the next action - action_mask_dict[env_id] = to_ndarray(obs['action_mask']) - to_play_dict[env_id] = to_ndarray(obs['to_play']) + self.action_mask_dict_tmp[env_id] = to_ndarray(obs['action_mask']) + self.to_play_dict_tmp[env_id] = to_ndarray(obs['to_play']) if self.policy_config.use_ture_chance_label_in_chance_encoder: - chance_dict[env_id] = to_ndarray(obs['chance']) + self.chance_dict_tmp[env_id] = to_ndarray(obs['chance']) if self.policy_config.ignore_done: - dones[env_id] = False + self.dones[env_id] = False else: - dones[env_id] = done + self.dones[env_id] = done if not collect_with_pure_policy: visit_entropies_lst[env_id] += visit_entropy_dict[env_id] @@ -575,10 +591,10 @@ def collect(self, # if game segment is full, we will save the last game segment if game_segments[env_id].is_full(): # pad over last segment trajectory - if last_game_segments[env_id] is not None: + if self.last_game_segments[env_id] is not None: # TODO(pu): return the one game segment self.pad_and_save_last_trajectory( - env_id, last_game_segments, last_game_priorities, game_segments, dones + env_id, self.last_game_segments, self.last_game_priorities, game_segments, self.dones ) # calculate priority @@ -589,8 +605,8 @@ def collect(self, improved_policy_lst[env_id] = [] # the current game_segments become last_game_segment - last_game_segments[env_id] = game_segments[env_id] - last_game_priorities[env_id] = priorities + self.last_game_segments[env_id] = game_segments[env_id] + self.last_game_priorities[env_id] = priorities # create new GameSegment game_segments[env_id] = GameSegment( @@ -628,9 +644,9 @@ def collect(self, # NOTE: put the penultimate game segment in one episode into the trajectory_pool # pad over 2th last game_segment using the last game_segment - if last_game_segments[env_id] is not None: + if self.last_game_segments[env_id] is not None: self.pad_and_save_last_trajectory( - env_id, last_game_segments, last_game_priorities, game_segments, dones + env_id, self.last_game_segments, self.last_game_priorities, game_segments, self.dones ) # store current segment trajectory @@ -642,51 +658,8 @@ def collect(self, # assert len(game_segments[env_id]) == len(priorities) # NOTE: save the last game segment in one episode into the trajectory_pool if it's not null if len(game_segments[env_id].reward_segment) != 0: - self.game_segment_pool.append((game_segments[env_id], priorities, dones[env_id])) - - # print(game_segments[env_id].reward_segment) - # reset the finished env and init game_segments - if n_episode > self._env_num: - # Get current ready env obs. - init_obs = self._env.ready_obs - retry_waiting_time = 0.001 - while len(init_obs.keys()) != self._env_num: - # In order to be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to - # len(self._env.ready_obs), especially in tictactoe env. - self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys())) - self._logger.info('Before sleeping, the _env_states is {}'.format(self._env._env_states)) - time.sleep(retry_waiting_time) - self._logger.info( - '=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' 
+ '=' * 10 - ) - self._logger.info( - 'After sleeping {}s, the current _env_states is {}'.format( - retry_waiting_time, self._env._env_states - ) - ) - init_obs = self._env.ready_obs - - new_available_env_id = set(init_obs.keys()).difference(ready_env_id) - ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) - remain_episode -= min(len(new_available_env_id), remain_episode) - - action_mask_dict[env_id] = to_ndarray(init_obs[env_id]['action_mask']) - to_play_dict[env_id] = to_ndarray(init_obs[env_id]['to_play']) - if self.policy_config.use_ture_chance_label_in_chance_encoder: - chance_dict[env_id] = to_ndarray(init_obs[env_id]['chance']) + self.game_segment_pool.append((game_segments[env_id], priorities, self.dones[env_id])) - game_segments[env_id] = GameSegment( - self._env.action_space, - game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config - ) - observation_window_stack[env_id] = deque( - [init_obs[env_id]['observation'] for _ in range(self.policy_config.model.frame_stack_num)], - maxlen=self.policy_config.model.frame_stack_num - ) - game_segments[env_id].reset(observation_window_stack[env_id]) - last_game_segments[env_id] = None - last_game_priorities[env_id] = None # log self_play_moves_max = max(self_play_moves_max, eps_steps_lst[env_id]) @@ -705,17 +678,13 @@ def collect(self, self._reset_stat(env_id) ready_env_id.remove(env_id) - # if collected_episode >= n_episode: - # if self.collected_game_segments >= self._default_num_segments or one_episode_done: # game_segment_length = 400 - # print(f'collect {self.collected_game_segments} segments now! one_episode_done: {one_episode_done}') - # self.collected_game_segments = 0 # The v1 version, which placed this check inside the for loop, likely lost the samples of some environments # The v2 version below returns the samples of every environment correctly - if len(self.game_segment_pool) >= self._default_num_segments or one_episode_done: # game_segment_length = 400 + # if len(self.game_segment_pool) >= self._default_num_segments or one_episode_done: # game_segment_length = 400 + if len(self.game_segment_pool) >= self._default_num_segments: # game_segment_length = 400 print(f'collect {len(self.game_segment_pool)} segments now! one_episode_done: {one_episode_done}') collected_enough_segments = True # The condition is met, so set the flag to True - # one_episode_done = False # [data, meta_data] return_data = [self.game_segment_pool[i][0] for i in range(len(self.game_segment_pool))], [ diff --git a/lzero/worker/muzero_segment_collector_bkp.py b/lzero/worker/muzero_segment_collector_bkp.py new file mode 100644 index 000000000..b37bc203d --- /dev/null +++ b/lzero/worker/muzero_segment_collector_bkp.py @@ -0,0 +1,761 @@ +import time +from collections import deque, namedtuple +from typing import Optional, Any, List + +import numpy as np +import torch +from ding.envs import BaseEnvManager +from ding.torch_utils import to_ndarray +from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY, get_rank, get_world_size, \ + allreduce_data +from ding.worker.collector.base_serial_collector import ISerialCollector +from torch.nn import L1Loss + +from lzero.mcts.buffer.game_segment import GameSegment +from lzero.mcts.utils import prepare_observation + + +@SERIAL_COLLECTOR_REGISTRY.register('segment_muzero') +class MuZeroSegmentCollector(ISerialCollector): + """ + Overview: + The Collector for MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero, Gumbel MuZero. + It manages the data collection process for training these algorithms using a serial mechanism.
+ Interfaces: + ``__init__``, ``reset``, ``reset_env``, ``reset_policy``, ``_reset_stat``, ``envstep``, ``__del__``, ``_compute_priorities``, + ``pad_and_save_last_trajectory``, ``collect``, ``_output_log``, ``close`` + Properties: + ``envstep`` + """ + + # TO be compatible with ISerialCollector + config = dict() + + def __init__( + self, + collect_print_freq: int = 100, + env: BaseEnvManager = None, + policy: namedtuple = None, + tb_logger: 'SummaryWriter' = None, # noqa + exp_name: Optional[str] = 'default_experiment', + instance_name: Optional[str] = 'collector', + policy_config: 'policy_config' = None, # noqa + ) -> None: + """ + Overview: + Initialize the MuZeroCollector with the given parameters. + Arguments: + - collect_print_freq (:obj:`int`): Frequency (in training steps) at which to print collection information. + - env (:obj:`Optional[BaseEnvManager]`): Instance of the subclass of vectorized environment manager. + - policy (:obj:`Optional[namedtuple]`): namedtuple of the collection mode policy API. + - tb_logger (:obj:`Optional[SummaryWriter]`): TensorBoard logger instance. + - exp_name (:obj:`str`): Name of the experiment, used for logging and saving purposes. + - instance_name (:obj:`str`): Unique identifier for this collector instance. + - policy_config (:obj:`Optional[policy_config]`): Configuration object for the policy. + """ + self._exp_name = exp_name + self._instance_name = instance_name + self._collect_print_freq = collect_print_freq + self._timer = EasyTimer() + self._end_flag = False + + self._rank = get_rank() + self._world_size = get_world_size() + if self._rank == 0: + if tb_logger is not None: + self._logger, _ = build_logger( + path='./{}/log/{}'.format(self._exp_name, self._instance_name), + name=self._instance_name, + need_tb=False + ) + self._tb_logger = tb_logger + else: + self._logger, self._tb_logger = build_logger( + path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name + ) + else: + self._logger, _ = build_logger( + path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, need_tb=False + ) + self._tb_logger = None + + self.policy_config = policy_config + self.collect_with_pure_policy = self.policy_config.collect_with_pure_policy + + self.reset(policy, env) + + def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None: + """ + Overview: + Reset or replace the environment managed by this collector. + If _env is None, reset the old environment. + If _env is not None, replace the old environment in the collector with the new passed \ + in environment and launch. + Arguments: + - env (:obj:`Optional[BaseEnvManager]`): New environment to manage, if provided. + """ + if _env is not None: + self._env = _env + self._env.launch() + self._env_num = self._env.env_num + else: + self._env.reset() + + def reset_policy(self, _policy: Optional[namedtuple] = None) -> None: + """ + Overview: + Reset or replace the policy used by this collector. + If _policy is None, reset the old policy. + If _policy is not None, replace the old policy in the collector with the new passed in policy. 
+ Arguments: + - policy (:obj:`Optional[namedtuple]`): the api namedtuple of collect_mode policy + """ + assert hasattr(self, '_env'), "please set env first" + if _policy is not None: + self._policy = _policy + self._default_n_episode = _policy.get_attribute('cfg').get('n_episode', None) + + self._default_num_segments = _policy.get_attribute('cfg').get('num_segments', None) + + self._logger.debug( + 'Set default n_episode mode(n_episode({}), env_num({}))'.format(self._default_n_episode, self._env_num) + ) + self._policy.reset() + + def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None: + """ + Overview: + Reset the collector with the given policy and/or environment. + If _env is None, reset the old environment. + If _env is not None, replace the old environment in the collector with the new passed \ + in environment and launch. + If _policy is None, reset the old policy. + If _policy is not None, replace the old policy in the collector with the new passed in policy. + Arguments: + - policy (:obj:`Optional[namedtuple]`): the api namedtuple of collect_mode policy + - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \ + env_manager(BaseEnvManager) + """ + if _env is not None: + self.reset_env(_env) + if _policy is not None: + self.reset_policy(_policy) + + self._env_info = {env_id: {'time': 0., 'step': 0} for env_id in range(self._env_num)} + + # 在此处初始化action_mask_dict, to_play_dict和chance_dict,确保它们包含所有env_id的值 + self.action_mask_dict = {i: None for i in range(self._env_num)} + self.to_play_dict = {i: None for i in range(self._env_num)} + if self.policy_config.use_ture_chance_label_in_chance_encoder: + self.chance_dict = {i: None for i in range(self._env_num)} + + self._episode_info = [] + self._total_envstep_count = 0 + self._total_episode_count = 0 + self._total_duration = 0 + self._last_train_iter = 0 + self._end_flag = False + + # A game_segment_pool implementation based on the deque structure. + self.game_segment_pool = deque(maxlen=int(1e6)) + self.unroll_plus_td_steps = self.policy_config.num_unroll_steps + self.policy_config.td_steps + + def _reset_stat(self, env_id: int) -> None: + """ + Overview: + Reset the collector's state. Including reset the traj_buffer, obs_pool, policy_output_pool \ + and env_info. Reset these states according to env_id. You can refer to base_serial_collector\ + to get more messages. + Arguments: + - env_id (:obj:`int`): the id where we need to reset the collector's state + """ + self._env_info[env_id] = {'time': 0., 'step': 0} + + @property + def envstep(self) -> int: + """ + Overview: + Get the total number of environment steps collected. + Returns: + - envstep (:obj:`int`): Total number of environment steps collected. + """ + return self._total_envstep_count + + def close(self) -> None: + """ + Overview: + Close the collector. If end_flag is False, close the environment, flush the tb_logger \ + and close the tb_logger. + """ + if self._end_flag: + return + self._end_flag = True + self._env.close() + if self._tb_logger: + self._tb_logger.flush() + self._tb_logger.close() + + def __del__(self) -> None: + """ + Overview: + Execute the close command and close the collector. 
__del__ is automatically called to \ + destroy the collector instance when the collector finishes its work + """ + self.close() + + # ============================================================== + # MCTS+RL related core code + # ============================================================== + def _compute_priorities(self, i: int, pred_values_lst: List[float], search_values_lst: List[float]) -> np.ndarray: + """ + Overview: + Compute the priorities for transitions based on prediction and search value discrepancies. + Arguments: + - i (:obj:`int`): Index of the values in the list to compute the priority for. + - pred_values_lst (:obj:`List[float]`): List of predicted values. + - search_values_lst (:obj:`List[float]`): List of search values obtained from MCTS. + Returns: + - priorities (:obj:`np.ndarray`): Array of computed priorities. + """ + if self.policy_config.use_priority: + # Calculate priorities. The priorities are the L1 losses between the predicted + # values and the search values. We use 'none' as the reduction parameter, which + # means the loss is calculated for each element individually, instead of being summed or averaged. + # A small constant (1e-6) is added to the results to avoid zero priorities. This + # is done because zero priorities could potentially cause issues in some scenarios. + pred_values = torch.from_numpy(np.array(pred_values_lst[i])).to(self.policy_config.device).float().view(-1) + search_values = torch.from_numpy(np.array(search_values_lst[i])).to(self.policy_config.device + ).float().view(-1) + priorities = L1Loss(reduction='none' + )(pred_values, + search_values).detach().cpu().numpy() + 1e-6 + else: + # priorities is None -> use the max priority for all newly collected data + priorities = None + + return priorities + + def pad_and_save_last_trajectory(self, i: int, last_game_segments: List[GameSegment], + last_game_priorities: List[np.ndarray], + game_segments: List[GameSegment], done: np.ndarray) -> None: + """ + Overview: + Save the game segment to the pool if the current game is finished, padding it if necessary. + Arguments: + - i (:obj:`int`): Index of the current game segment. + - last_game_segments (:obj:`List[GameSegment]`): List of the last game segments to be padded and saved. + - last_game_priorities (:obj:`List[np.ndarray]`): List of priorities of the last game segments. + - game_segments (:obj:`List[GameSegment]`): List of the current game segments. + - done (:obj:`np.ndarray`): Array indicating whether each game is done. + Note: + (last_game_segments[i].obs_segment[-4:][j] == game_segments[i].obs_segment[:4][j]).all() is True + """ + # pad over last segment trajectory + beg_index = self.policy_config.model.frame_stack_num + end_index = beg_index + self.policy_config.num_unroll_steps + + # the start obs is init zero obs, so we take the + # [ : +] obs as the pad obs + # e.g. 
the start 4 obs is init zero obs, the num_unroll_steps is 5, so we take the [4:9] obs as the pad obs + pad_obs_lst = game_segments[i].obs_segment[beg_index:end_index] + pad_child_visits_lst = game_segments[i].child_visit_segment[:self.policy_config.num_unroll_steps] + # EfficientZero original repo bug: + # pad_child_visits_lst = game_segments[i].child_visit_segment[beg_index:end_index] + + beg_index = 0 + # self.unroll_plus_td_steps = self.policy_config.num_unroll_steps + self.policy_config.td_steps + end_index = beg_index + self.unroll_plus_td_steps - 1 + + pad_reward_lst = game_segments[i].reward_segment[beg_index:end_index] + if self.policy_config.use_ture_chance_label_in_chance_encoder: + chance_lst = game_segments[i].chance_segment[beg_index:end_index] + + beg_index = 0 + end_index = beg_index + self.unroll_plus_td_steps + + pad_root_values_lst = game_segments[i].root_value_segment[beg_index:end_index] + + if self.policy_config.gumbel_algo: + pad_improved_policy_prob = game_segments[i].improved_policy_probs[beg_index:end_index] + + # pad over and save + if self.policy_config.gumbel_algo: + last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_root_values_lst, pad_child_visits_lst, + next_segment_improved_policy=pad_improved_policy_prob) + else: + if self.policy_config.use_ture_chance_label_in_chance_encoder: + last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_root_values_lst, pad_child_visits_lst, + next_chances=chance_lst) + else: + last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_root_values_lst, pad_child_visits_lst) + """ + Note: + game_segment element shape: + obs: game_segment_length + stack + num_unroll_steps, 20+4 +5 + rew: game_segment_length + stack + num_unroll_steps + td_steps -1 20 +5+3-1 + action: game_segment_length -> 20 + root_values: game_segment_length + num_unroll_steps + td_steps -> 20 +5+3 + child_visits: game_segment_length + num_unroll_steps -> 20 +5 + to_play: game_segment_length -> 20 + action_mask: game_segment_length -> 20 + """ + + last_game_segments[i].game_segment_to_array() + + self.collected_game_segments += 1 + + # put the game segment into the pool + self.game_segment_pool.append((last_game_segments[i], last_game_priorities[i], done[i])) + + # reset last game_segments + last_game_segments[i] = None + last_game_priorities[i] = None + + + return None + + def collect(self, + n_episode: Optional[int] = None, + train_iter: int = 0, + policy_kwargs: Optional[dict] = None, + collect_with_pure_policy: bool = False) -> List[Any]: + """ + Overview: + Collect `n_episode` episodes of data with policy_kwargs, trained for `train_iter` iterations. + Arguments: + - n_episode (:obj:`Optional[int]`): Number of episodes to collect. + - train_iter (:obj:`int`): Number of training iterations completed so far. + - policy_kwargs (:obj:`Optional[dict]`): Additional keyword arguments for the policy. + - collect_with_pure_policy (:obj:`bool`): Whether to collect data using pure policy without MCTS. + Returns: + - return_data (:obj:`List[Any]`): Collected data in the form of a list. 
+ """ + # TODO: collect_with_pure_policy as a separate collector + if n_episode is None: + if self._default_n_episode is None: + raise RuntimeError("Please specify collect n_episode") + else: + n_episode = self._default_n_episode + assert n_episode >= self._env_num, "Please make sure n_episode >= env_num{}/{}".format(n_episode, self._env_num) + if policy_kwargs is None: + policy_kwargs = {} + temperature = policy_kwargs['temperature'] + epsilon = policy_kwargs['epsilon'] + + collected_episode = 0 + collected_step = 0 + env_nums = self._env_num + + # initializations + init_obs = self._env.ready_obs + + # retry_waiting_time = 0.05 + # while len(init_obs.keys()) != self._env_num: + # # To be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to + # # len(self._env.ready_obs), especially in tictactoe env. + # self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys())) + # self._logger.info('Before sleeping, the _env_states is {}'.format(self._env._env_states)) + # time.sleep(retry_waiting_time) + # self._logger.info('=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' + '=' * 10) + # self._logger.info( + # 'After sleeping {}s, the current _env_states is {}'.format(retry_waiting_time, self._env._env_states) + # ) + # init_obs = self._env.ready_obs + + # action_mask_dict = {i: to_ndarray(init_obs[i]['action_mask']) for i in range(env_nums)} + # to_play_dict = {i: to_ndarray(init_obs[i]['to_play']) for i in range(env_nums)} + # if self.policy_config.use_ture_chance_label_in_chance_encoder: + # chance_dict = {i: to_ndarray(init_obs[i]['chance']) for i in range(env_nums)} + + # 改为直接使用self.action_mask_dict等变量 + for env_id in range(env_nums): + self.action_mask_dict[env_id] = to_ndarray(init_obs[env_id]['action_mask']) + self.to_play_dict[env_id] = to_ndarray(init_obs[env_id]['to_play']) + if self.policy_config.use_ture_chance_label_in_chance_encoder: + self.chance_dict[env_id] = to_ndarray(init_obs[env_id]['chance']) + + + game_segments = [ + GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) for _ in range(env_nums) + ] + # stacked observation windows in reset stage for init game_segments + observation_window_stack = [[] for _ in range(env_nums)] + for env_id in range(env_nums): + observation_window_stack[env_id] = deque( + [to_ndarray(init_obs[env_id]['observation']) for _ in range(self.policy_config.model.frame_stack_num)], + maxlen=self.policy_config.model.frame_stack_num + ) + + game_segments[env_id].reset(observation_window_stack[env_id]) + + dones = np.array([False for _ in range(env_nums)]) + last_game_segments = [None for _ in range(env_nums)] + last_game_priorities = [None for _ in range(env_nums)] + + # for priorities in self-play + search_values_lst = [[] for _ in range(env_nums)] + pred_values_lst = [[] for _ in range(env_nums)] + if self.policy_config.gumbel_algo: + improved_policy_lst = [[] for _ in range(env_nums)] + + # some logs + eps_steps_lst, visit_entropies_lst = np.zeros(env_nums), np.zeros(env_nums) + if self.policy_config.gumbel_algo: + completed_value_lst = np.zeros(env_nums) + self_play_moves = 0. + self_play_episodes = 0. 
+ self_play_moves_max = 0 + self_play_visit_entropy = [] + total_transitions = 0 + + if collect_with_pure_policy: + temp_visit_list = [0.0 for i in range(self._env.action_space.n)] + + self.collected_game_segments = 0 + one_episode_done = False + collected_enough_segments = False # 标志变量 + + while True: + with self._timer: + # Get current ready env obs. + obs = self._env.ready_obs + ready_env_id = set(obs.keys()) + + + stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id} + stack_obs = list(stack_obs.values()) + + self.action_mask_dict = {env_id: self.action_mask_dict[env_id] for env_id in ready_env_id} + self.to_play_dict = {env_id: self.to_play_dict[env_id] for env_id in ready_env_id} + action_mask = [self.action_mask_dict[env_id] for env_id in ready_env_id] + to_play = [self.to_play_dict[env_id] for env_id in ready_env_id] + if self.policy_config.use_ture_chance_label_in_chance_encoder: + self.chance_dict = {env_id: self.chance_dict[env_id] for env_id in ready_env_id} + + stack_obs = to_ndarray(stack_obs) + # return stack_obs shape: [B, S*C, W, H] e.g. [8, 4*1, 96, 96] + stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) + stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device) + + # ============================================================== + # Key policy forward step + # ============================================================== + # print(f'ready_env_id:{ready_env_id}') + policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id) + + # Extract relevant policy outputs + actions_with_env_id = {k: v['action'] for k, v in policy_output.items()} + value_dict_with_env_id = {k: v['searched_value'] for k, v in policy_output.items()} + pred_value_dict_with_env_id = {k: v['predicted_value'] for k, v in policy_output.items()} + + if self.policy_config.sampled_algo: + root_sampled_actions_dict_with_env_id = { + k: v['root_sampled_actions'] for k, v in policy_output.items() + } + + if not collect_with_pure_policy: + distributions_dict_with_env_id = {k: v['visit_count_distributions'] for k, v in + policy_output.items()} + visit_entropy_dict_with_env_id = {k: v['visit_count_distribution_entropy'] for k, v in + policy_output.items()} + + if self.policy_config.gumbel_algo: + improved_policy_dict_with_env_id = {k: v['improved_policy_probs'] for k, v in + policy_output.items()} + completed_value_with_env_id = {k: v['roots_completed_value'] for k, v in policy_output.items()} + + # Initialize dictionaries to store results + actions = {} + value_dict = {} + pred_value_dict = {} + + if not collect_with_pure_policy: + distributions_dict = {} + visit_entropy_dict = {} + + if self.policy_config.sampled_algo: + root_sampled_actions_dict = {} + + if self.policy_config.gumbel_algo: + improved_policy_dict = {} + completed_value_dict = {} + + # Populate the result dictionaries + for env_id in ready_env_id: + actions[env_id] = actions_with_env_id.pop(env_id) + value_dict[env_id] = value_dict_with_env_id.pop(env_id) + pred_value_dict[env_id] = pred_value_dict_with_env_id.pop(env_id) + + if not collect_with_pure_policy: + distributions_dict[env_id] = distributions_dict_with_env_id.pop(env_id) + + if self.policy_config.sampled_algo: + root_sampled_actions_dict[env_id] = root_sampled_actions_dict_with_env_id.pop(env_id) + + visit_entropy_dict[env_id] = visit_entropy_dict_with_env_id.pop(env_id) + + if self.policy_config.gumbel_algo: + improved_policy_dict[env_id] = 
improved_policy_dict_with_env_id.pop(env_id) + completed_value_dict[env_id] = completed_value_with_env_id.pop(env_id) + + # ============================================================== + # Interact with the environment + # ============================================================== + timesteps = self._env.step(actions) + + interaction_duration = self._timer.value / len(timesteps) + + for env_id, timestep in timesteps.items(): + with self._timer: + if timestep.info.get('abnormal', False): + # If there is an abnormal timestep, reset all the related variables(including this env). + # suppose there is no reset param, reset this env + self._env.reset({env_id: None}) + self._policy.reset([env_id]) + self._reset_stat(env_id) + self._logger.info('Env{} returns a abnormal step, its info is {}'.format(env_id, timestep.info)) + continue + obs, reward, done, info = timestep.obs, timestep.reward, timestep.done, timestep.info + + if collect_with_pure_policy: + game_segments[env_id].store_search_stats(temp_visit_list, 0) + else: + if self.policy_config.sampled_algo: + game_segments[env_id].store_search_stats( + distributions_dict[env_id], value_dict[env_id], root_sampled_actions_dict[env_id] + ) + elif self.policy_config.gumbel_algo: + game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id], + improved_policy=improved_policy_dict[env_id]) + else: + game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id]) + + # append a transition tuple, including a_t, o_{t+1}, r_{t}, action_mask_{t}, to_play_{t} + # in ``game_segments[env_id].init``, we have appended o_{t} in ``self.obs_segment`` + if self.policy_config.use_ture_chance_label_in_chance_encoder: + game_segments[env_id].append( + actions[env_id], to_ndarray(obs['observation']), reward, self.action_mask_dict[env_id], + self.to_play_dict[env_id], self.chance_dict[env_id] + ) + else: + game_segments[env_id].append( + actions[env_id], to_ndarray(obs['observation']), reward, self.action_mask_dict[env_id], + self.to_play_dict[env_id] + ) + + # NOTE: the position of code snippet is very important. 
+ # the obs['action_mask'] and obs['to_play'] are corresponding to the next action + self.action_mask_dict[env_id] = to_ndarray(obs['action_mask']) + self.to_play_dict[env_id] = to_ndarray(obs['to_play']) + if self.policy_config.use_ture_chance_label_in_chance_encoder: + self.chance_dict[env_id] = to_ndarray(obs['chance']) + + if self.policy_config.ignore_done: + dones[env_id] = False + else: + dones[env_id] = done + + if not collect_with_pure_policy: + visit_entropies_lst[env_id] += visit_entropy_dict[env_id] + if self.policy_config.gumbel_algo: + completed_value_lst[env_id] += np.mean(np.array(completed_value_dict[env_id])) + + eps_steps_lst[env_id] += 1 + if self._policy.get_attribute('cfg').type == 'unizero': + # only for UniZero now + self._policy.reset(env_id=env_id, current_steps=eps_steps_lst[env_id], reset_init_data=False) + + total_transitions += 1 + + if self.policy_config.use_priority: + pred_values_lst[env_id].append(pred_value_dict[env_id]) + search_values_lst[env_id].append(value_dict[env_id]) + if self.policy_config.gumbel_algo and not collect_with_pure_policy: + improved_policy_lst[env_id].append(improved_policy_dict[env_id]) + + # append the newest obs + observation_window_stack[env_id].append(to_ndarray(obs['observation'])) + + # ============================================================== + # we will save a game segment if it is the end of the game or the next game segment is finished. + # ============================================================== + + # if game segment is full, we will save the last game segment + if game_segments[env_id].is_full(): + # pad over last segment trajectory + if last_game_segments[env_id] is not None: + # TODO(pu): return the one game segment + self.pad_and_save_last_trajectory( + env_id, last_game_segments, last_game_priorities, game_segments, dones + ) + + # calculate priority + priorities = self._compute_priorities(env_id, pred_values_lst, search_values_lst) + pred_values_lst[env_id] = [] + search_values_lst[env_id] = [] + if self.policy_config.gumbel_algo and not collect_with_pure_policy: + improved_policy_lst[env_id] = [] + + # the current game_segments become last_game_segment + last_game_segments[env_id] = game_segments[env_id] + last_game_priorities[env_id] = priorities + + # create new GameSegment + game_segments[env_id] = GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) + game_segments[env_id].reset(observation_window_stack[env_id]) + + self._env_info[env_id]['step'] += 1 + collected_step += 1 + + self._env_info[env_id]['time'] += self._timer.value + interaction_duration + if timestep.done: + one_episode_done = True + self._total_episode_count += 1 + + reward = timestep.info['eval_episode_return'] + info = { + 'reward': reward, + 'time': self._env_info[env_id]['time'], + 'step': self._env_info[env_id]['step'], + } + if not collect_with_pure_policy: + info['visit_entropy'] = visit_entropies_lst[env_id] / eps_steps_lst[env_id] + if self.policy_config.gumbel_algo: + info['completed_value'] = completed_value_lst[env_id] / eps_steps_lst[env_id] + + collected_episode += 1 + self._episode_info.append(info) + + # ============================================================== + # if it is the end of the game, we will save the game segment + # ============================================================== + + # NOTE: put the penultimate game segment in one episode into the trajectory_pool + # pad over 2th last game_segment using the last game_segment + 
if last_game_segments[env_id] is not None: + self.pad_and_save_last_trajectory( + env_id, last_game_segments, last_game_priorities, game_segments, dones + ) + + # store current segment trajectory + priorities = self._compute_priorities(env_id, pred_values_lst, search_values_lst) + + # NOTE: put the last game segment in one episode into the trajectory_pool + game_segments[env_id].game_segment_to_array() + + # assert len(game_segments[env_id]) == len(priorities) + # NOTE: save the last game segment in one episode into the trajectory_pool if it's not null + if len(game_segments[env_id].reward_segment) != 0: + self.game_segment_pool.append((game_segments[env_id], priorities, dones[env_id])) + + + # log + self_play_moves_max = max(self_play_moves_max, eps_steps_lst[env_id]) + if not collect_with_pure_policy: + self_play_visit_entropy.append(visit_entropies_lst[env_id] / eps_steps_lst[env_id]) + self_play_moves += eps_steps_lst[env_id] + self_play_episodes += 1 + + pred_values_lst[env_id] = [] + search_values_lst[env_id] = [] + eps_steps_lst[env_id] = 0 + visit_entropies_lst[env_id] = 0 + + # Env reset is done by env_manager automatically + self._policy.reset([env_id]) # NOTE: reset the policy for the env_id. Default reset_init_data=True. + self._reset_stat(env_id) + ready_env_id.remove(env_id) + + + # 如果放到for循环里面去的v1版本,应该是丢失了部分环境的样本 + # 下面的v2版本,是将个环境的样本都正确返回了 + # if len(self.game_segment_pool) >= self._default_num_segments or one_episode_done: # game_segment_length = 400 + if len(self.game_segment_pool) >= self._default_num_segments: # game_segment_length = 400 + print(f'collect {len(self.game_segment_pool)} segments now! one_episode_done: {one_episode_done}') + collected_enough_segments = True # 条件满足,设置标志变量为 True + + # [data, meta_data] + return_data = [self.game_segment_pool[i][0] for i in range(len(self.game_segment_pool))], [ + { + 'priorities': self.game_segment_pool[i][1], + 'done': self.game_segment_pool[i][2], + 'unroll_plus_td_steps': self.unroll_plus_td_steps + } for i in range(len(self.game_segment_pool)) + ] + self.game_segment_pool.clear() + break + + # if collected_enough_segments: # 如果满足条件,则跳出外层 for 循环 + # break + + collected_duration = sum([d['time'] for d in self._episode_info]) + # reduce data when enables DDP + if self._world_size > 1: + collected_step = allreduce_data(collected_step, 'sum') + collected_episode = allreduce_data(collected_episode, 'sum') + collected_duration = allreduce_data(collected_duration, 'sum') + self._total_envstep_count += collected_step + self._total_episode_count += collected_episode + self._total_duration += collected_duration + + # log + self._output_log(train_iter) + return return_data + + def _output_log(self, train_iter: int) -> None: + """ + Overview: + Log the collector's data and output the log information. + Arguments: + - train_iter (:obj:`int`): Current training iteration number for logging context. 
+ """ + if self._rank != 0: + return + if (train_iter - self._last_train_iter) >= self._collect_print_freq and len(self._episode_info) > 0: + self._last_train_iter = train_iter + episode_count = len(self._episode_info) + envstep_count = sum([d['step'] for d in self._episode_info]) + duration = sum([d['time'] for d in self._episode_info]) + episode_reward = [d['reward'] for d in self._episode_info] + if not self.collect_with_pure_policy: + visit_entropy = [d['visit_entropy'] for d in self._episode_info] + else: + visit_entropy = [0.0] + if self.policy_config.gumbel_algo: + completed_value = [d['completed_value'] for d in self._episode_info] + self._total_duration += duration + info = { + 'episode_count': episode_count, + 'envstep_count': envstep_count, + 'avg_envstep_per_episode': envstep_count / episode_count, + 'avg_envstep_per_sec': envstep_count / duration, + 'avg_episode_per_sec': episode_count / duration, + 'collect_time': duration, + 'reward_mean': np.mean(episode_reward), + 'reward_std': np.std(episode_reward), + 'reward_max': np.max(episode_reward), + 'reward_min': np.min(episode_reward), + 'total_envstep_count': self._total_envstep_count, + 'total_episode_count': self._total_episode_count, + 'total_duration': self._total_duration, + 'visit_entropy': np.mean(visit_entropy), + } + if self.policy_config.gumbel_algo: + info['completed_value'] = np.mean(completed_value) + self._episode_info.clear() + self._logger.info("collect end:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()]))) + for k, v in info.items(): + if k in ['each_reward']: + continue + self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) + if k in ['total_envstep_count']: + continue + self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count) \ No newline at end of file diff --git a/lzero/worker/muzero_segment_collector_v1.py b/lzero/worker/muzero_segment_collector_v1.py new file mode 100644 index 000000000..b6821ef9b --- /dev/null +++ b/lzero/worker/muzero_segment_collector_v1.py @@ -0,0 +1,796 @@ +import time +from collections import deque, namedtuple +from typing import Optional, Any, List + +import numpy as np +import torch +from ding.envs import BaseEnvManager +from ding.torch_utils import to_ndarray +from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY, get_rank, get_world_size, \ + allreduce_data +from ding.worker.collector.base_serial_collector import ISerialCollector +from torch.nn import L1Loss + +from lzero.mcts.buffer.game_segment import GameSegment +from lzero.mcts.utils import prepare_observation + + +@SERIAL_COLLECTOR_REGISTRY.register('segment_muzero') +class MuZeroSegmentCollector(ISerialCollector): + """ + Overview: + The Collector for MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero, Gumbel MuZero. + It manages the data collection process for training these algorithms using a serial mechanism. 
+ Interfaces: + ``__init__``, ``reset``, ``reset_env``, ``reset_policy``, ``_reset_stat``, ``envstep``, ``__del__``, ``_compute_priorities``, + ``pad_and_save_last_trajectory``, ``collect``, ``_output_log``, ``close`` + Properties: + ``envstep`` + """ + + # TO be compatible with ISerialCollector + config = dict() + + def __init__( + self, + collect_print_freq: int = 100, + env: BaseEnvManager = None, + policy: namedtuple = None, + tb_logger: 'SummaryWriter' = None, # noqa + exp_name: Optional[str] = 'default_experiment', + instance_name: Optional[str] = 'collector', + policy_config: 'policy_config' = None, # noqa + ) -> None: + """ + Overview: + Initialize the MuZeroCollector with the given parameters. + Arguments: + - collect_print_freq (:obj:`int`): Frequency (in training steps) at which to print collection information. + - env (:obj:`Optional[BaseEnvManager]`): Instance of the subclass of vectorized environment manager. + - policy (:obj:`Optional[namedtuple]`): namedtuple of the collection mode policy API. + - tb_logger (:obj:`Optional[SummaryWriter]`): TensorBoard logger instance. + - exp_name (:obj:`str`): Name of the experiment, used for logging and saving purposes. + - instance_name (:obj:`str`): Unique identifier for this collector instance. + - policy_config (:obj:`Optional[policy_config]`): Configuration object for the policy. + """ + self._exp_name = exp_name + self._instance_name = instance_name + self._collect_print_freq = collect_print_freq + self._timer = EasyTimer() + self._end_flag = False + + self._rank = get_rank() + self._world_size = get_world_size() + if self._rank == 0: + if tb_logger is not None: + self._logger, _ = build_logger( + path='./{}/log/{}'.format(self._exp_name, self._instance_name), + name=self._instance_name, + need_tb=False + ) + self._tb_logger = tb_logger + else: + self._logger, self._tb_logger = build_logger( + path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name + ) + else: + self._logger, _ = build_logger( + path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, need_tb=False + ) + self._tb_logger = None + + self.policy_config = policy_config + self.collect_with_pure_policy = self.policy_config.collect_with_pure_policy + + self.reset(policy, env) + + def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None: + """ + Overview: + Reset or replace the environment managed by this collector. + If _env is None, reset the old environment. + If _env is not None, replace the old environment in the collector with the new passed \ + in environment and launch. + Arguments: + - env (:obj:`Optional[BaseEnvManager]`): New environment to manage, if provided. + """ + if _env is not None: + self._env = _env + self._env.launch() + self._env_num = self._env.env_num + else: + self._env.reset() + + def reset_policy(self, _policy: Optional[namedtuple] = None) -> None: + """ + Overview: + Reset or replace the policy used by this collector. + If _policy is None, reset the old policy. + If _policy is not None, replace the old policy in the collector with the new passed in policy. 
+ Arguments: + - policy (:obj:`Optional[namedtuple]`): the api namedtuple of collect_mode policy + """ + assert hasattr(self, '_env'), "please set env first" + if _policy is not None: + self._policy = _policy + self._default_n_episode = _policy.get_attribute('cfg').get('n_episode', None) + + self._default_num_segments = _policy.get_attribute('cfg').get('num_segments', None) + + self._logger.debug( + 'Set default n_episode mode(n_episode({}), env_num({}))'.format(self._default_n_episode, self._env_num) + ) + self._policy.reset() + + def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None: + """ + Overview: + Reset the collector with the given policy and/or environment. + If _env is None, reset the old environment. + If _env is not None, replace the old environment in the collector with the new passed \ + in environment and launch. + If _policy is None, reset the old policy. + If _policy is not None, replace the old policy in the collector with the new passed in policy. + Arguments: + - policy (:obj:`Optional[namedtuple]`): the api namedtuple of collect_mode policy + - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \ + env_manager(BaseEnvManager) + """ + if _env is not None: + self.reset_env(_env) + if _policy is not None: + self.reset_policy(_policy) + + self._env_info = {env_id: {'time': 0., 'step': 0} for env_id in range(self._env_num)} + + self._episode_info = [] + self._total_envstep_count = 0 + self._total_episode_count = 0 + self._total_duration = 0 + self._last_train_iter = 0 + self._end_flag = False + + # A game_segment_pool implementation based on the deque structure. + self.game_segment_pool = deque(maxlen=int(1e6)) + self.unroll_plus_td_steps = self.policy_config.num_unroll_steps + self.policy_config.td_steps + + def _reset_stat(self, env_id: int) -> None: + """ + Overview: + Reset the collector's state. Including reset the traj_buffer, obs_pool, policy_output_pool \ + and env_info. Reset these states according to env_id. You can refer to base_serial_collector\ + to get more messages. + Arguments: + - env_id (:obj:`int`): the id where we need to reset the collector's state + """ + self._env_info[env_id] = {'time': 0., 'step': 0} + + @property + def envstep(self) -> int: + """ + Overview: + Get the total number of environment steps collected. + Returns: + - envstep (:obj:`int`): Total number of environment steps collected. + """ + return self._total_envstep_count + + def close(self) -> None: + """ + Overview: + Close the collector. If end_flag is False, close the environment, flush the tb_logger \ + and close the tb_logger. + """ + if self._end_flag: + return + self._end_flag = True + self._env.close() + if self._tb_logger: + self._tb_logger.flush() + self._tb_logger.close() + + def __del__(self) -> None: + """ + Overview: + Execute the close command and close the collector. __del__ is automatically called to \ + destroy the collector instance when the collector finishes its work + """ + self.close() + + # ============================================================== + # MCTS+RL related core code + # ============================================================== + def _compute_priorities(self, i: int, pred_values_lst: List[float], search_values_lst: List[float]) -> np.ndarray: + """ + Overview: + Compute the priorities for transitions based on prediction and search value discrepancies. + Arguments: + - i (:obj:`int`): Index of the values in the list to compute the priority for. 
+ - pred_values_lst (:obj:`List[float]`): List of predicted values. + - search_values_lst (:obj:`List[float]`): List of search values obtained from MCTS. + Returns: + - priorities (:obj:`np.ndarray`): Array of computed priorities. + """ + if self.policy_config.use_priority: + # Calculate priorities. The priorities are the L1 losses between the predicted + # values and the search values. We use 'none' as the reduction parameter, which + # means the loss is calculated for each element individually, instead of being summed or averaged. + # A small constant (1e-6) is added to the results to avoid zero priorities. This + # is done because zero priorities could potentially cause issues in some scenarios. + pred_values = torch.from_numpy(np.array(pred_values_lst[i])).to(self.policy_config.device).float().view(-1) + search_values = torch.from_numpy(np.array(search_values_lst[i])).to(self.policy_config.device + ).float().view(-1) + priorities = L1Loss(reduction='none' + )(pred_values, + search_values).detach().cpu().numpy() + 1e-6 + else: + # priorities is None -> use the max priority for all newly collected data + priorities = None + + return priorities + + def pad_and_save_last_trajectory(self, i: int, last_game_segments: List[GameSegment], + last_game_priorities: List[np.ndarray], + game_segments: List[GameSegment], done: np.ndarray) -> None: + """ + Overview: + Save the game segment to the pool if the current game is finished, padding it if necessary. + Arguments: + - i (:obj:`int`): Index of the current game segment. + - last_game_segments (:obj:`List[GameSegment]`): List of the last game segments to be padded and saved. + - last_game_priorities (:obj:`List[np.ndarray]`): List of priorities of the last game segments. + - game_segments (:obj:`List[GameSegment]`): List of the current game segments. + - done (:obj:`np.ndarray`): Array indicating whether each game is done. + Note: + (last_game_segments[i].obs_segment[-4:][j] == game_segments[i].obs_segment[:4][j]).all() is True + """ + # pad over last segment trajectory + beg_index = self.policy_config.model.frame_stack_num + end_index = beg_index + self.policy_config.num_unroll_steps + + # the start obs is init zero obs, so we take the + # [ : +] obs as the pad obs + # e.g. 
the start 4 obs is init zero obs, the num_unroll_steps is 5, so we take the [4:9] obs as the pad obs + pad_obs_lst = game_segments[i].obs_segment[beg_index:end_index] + pad_child_visits_lst = game_segments[i].child_visit_segment[:self.policy_config.num_unroll_steps] + # EfficientZero original repo bug: + # pad_child_visits_lst = game_segments[i].child_visit_segment[beg_index:end_index] + + beg_index = 0 + # self.unroll_plus_td_steps = self.policy_config.num_unroll_steps + self.policy_config.td_steps + end_index = beg_index + self.unroll_plus_td_steps - 1 + + pad_reward_lst = game_segments[i].reward_segment[beg_index:end_index] + if self.policy_config.use_ture_chance_label_in_chance_encoder: + chance_lst = game_segments[i].chance_segment[beg_index:end_index] + + beg_index = 0 + end_index = beg_index + self.unroll_plus_td_steps + + pad_root_values_lst = game_segments[i].root_value_segment[beg_index:end_index] + + if self.policy_config.gumbel_algo: + pad_improved_policy_prob = game_segments[i].improved_policy_probs[beg_index:end_index] + + # pad over and save + if self.policy_config.gumbel_algo: + last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_root_values_lst, pad_child_visits_lst, + next_segment_improved_policy=pad_improved_policy_prob) + else: + if self.policy_config.use_ture_chance_label_in_chance_encoder: + last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_root_values_lst, pad_child_visits_lst, + next_chances=chance_lst) + else: + last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_root_values_lst, pad_child_visits_lst) + """ + Note: + game_segment element shape: + obs: game_segment_length + stack + num_unroll_steps, 20+4 +5 + rew: game_segment_length + stack + num_unroll_steps + td_steps -1 20 +5+3-1 + action: game_segment_length -> 20 + root_values: game_segment_length + num_unroll_steps + td_steps -> 20 +5+3 + child_visits: game_segment_length + num_unroll_steps -> 20 +5 + to_play: game_segment_length -> 20 + action_mask: game_segment_length -> 20 + """ + + last_game_segments[i].game_segment_to_array() + + self.collected_game_segments += 1 + + # put the game segment into the pool + self.game_segment_pool.append((last_game_segments[i], last_game_priorities[i], done[i])) + + # reset last game_segments + last_game_segments[i] = None + last_game_priorities[i] = None + + + return None + + def collect(self, + n_episode: Optional[int] = None, + train_iter: int = 0, + policy_kwargs: Optional[dict] = None, + collect_with_pure_policy: bool = False) -> List[Any]: + """ + Overview: + Collect `n_episode` episodes of data with policy_kwargs, trained for `train_iter` iterations. + Arguments: + - n_episode (:obj:`Optional[int]`): Number of episodes to collect. + - train_iter (:obj:`int`): Number of training iterations completed so far. + - policy_kwargs (:obj:`Optional[dict]`): Additional keyword arguments for the policy. + - collect_with_pure_policy (:obj:`bool`): Whether to collect data using pure policy without MCTS. + Returns: + - return_data (:obj:`List[Any]`): Collected data in the form of a list. 
+ """ + # TODO: collect_with_pure_policy as a separate collector + if n_episode is None: + if self._default_n_episode is None: + raise RuntimeError("Please specify collect n_episode") + else: + n_episode = self._default_n_episode + assert n_episode >= self._env_num, "Please make sure n_episode >= env_num{}/{}".format(n_episode, self._env_num) + if policy_kwargs is None: + policy_kwargs = {} + temperature = policy_kwargs['temperature'] + epsilon = policy_kwargs['epsilon'] + + collected_episode = 0 + collected_step = 0 + env_nums = self._env_num + + # initializations + init_obs = self._env.ready_obs + + retry_waiting_time = 0.05 + while len(init_obs.keys()) != self._env_num: + # To be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to + # len(self._env.ready_obs), especially in tictactoe env. + self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys())) + self._logger.info('Before sleeping, the _env_states is {}'.format(self._env._env_states)) + time.sleep(retry_waiting_time) + self._logger.info('=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' + '=' * 10) + self._logger.info( + 'After sleeping {}s, the current _env_states is {}'.format(retry_waiting_time, self._env._env_states) + ) + init_obs = self._env.ready_obs + + action_mask_dict = {i: to_ndarray(init_obs[i]['action_mask']) for i in range(env_nums)} + to_play_dict = {i: to_ndarray(init_obs[i]['to_play']) for i in range(env_nums)} + if self.policy_config.use_ture_chance_label_in_chance_encoder: + chance_dict = {i: to_ndarray(init_obs[i]['chance']) for i in range(env_nums)} + + game_segments = [ + GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config + ) for _ in range(env_nums) + ] + # stacked observation windows in reset stage for init game_segments + observation_window_stack = [[] for _ in range(env_nums)] + for env_id in range(env_nums): + observation_window_stack[env_id] = deque( + [to_ndarray(init_obs[env_id]['observation']) for _ in range(self.policy_config.model.frame_stack_num)], + maxlen=self.policy_config.model.frame_stack_num + ) + + game_segments[env_id].reset(observation_window_stack[env_id]) + + dones = np.array([False for _ in range(env_nums)]) + last_game_segments = [None for _ in range(env_nums)] + last_game_priorities = [None for _ in range(env_nums)] + # for priorities in self-play + search_values_lst = [[] for _ in range(env_nums)] + pred_values_lst = [[] for _ in range(env_nums)] + if self.policy_config.gumbel_algo: + improved_policy_lst = [[] for _ in range(env_nums)] + + # some logs + eps_steps_lst, visit_entropies_lst = np.zeros(env_nums), np.zeros(env_nums) + if self.policy_config.gumbel_algo: + completed_value_lst = np.zeros(env_nums) + self_play_moves = 0. + self_play_episodes = 0. + self_play_moves_max = 0 + self_play_visit_entropy = [] + total_transitions = 0 + + ready_env_id = set() + remain_episode = n_episode + if collect_with_pure_policy: + temp_visit_list = [0.0 for i in range(self._env.action_space.n)] + + self.collected_game_segments = 0 + one_episode_done = False + collected_enough_segments = False # 标志变量 + + while True: + with self._timer: + # Get current ready env obs. 
+                obs = self._env.ready_obs
+                new_available_env_id = set(obs.keys()).difference(ready_env_id)
+                ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode]))
+                remain_episode -= min(len(new_available_env_id), remain_episode)
+
+                stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id}
+                stack_obs = list(stack_obs.values())
+
+                action_mask_dict = {env_id: action_mask_dict[env_id] for env_id in ready_env_id}
+                to_play_dict = {env_id: to_play_dict[env_id] for env_id in ready_env_id}
+                action_mask = [action_mask_dict[env_id] for env_id in ready_env_id]
+                to_play = [to_play_dict[env_id] for env_id in ready_env_id]
+                if self.policy_config.use_ture_chance_label_in_chance_encoder:
+                    chance_dict = {env_id: chance_dict[env_id] for env_id in ready_env_id}
+
+                stack_obs = to_ndarray(stack_obs)
+                # return stack_obs shape: [B, S*C, W, H] e.g. [8, 4*1, 96, 96]
+                stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type)
+                stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device)
+
+                # ==============================================================
+                # Key policy forward step
+                # ==============================================================
+                # print(f'ready_env_id:{ready_env_id}')
+                policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id)
+
+                # Extract relevant policy outputs
+                actions_with_env_id = {k: v['action'] for k, v in policy_output.items()}
+                value_dict_with_env_id = {k: v['searched_value'] for k, v in policy_output.items()}
+                pred_value_dict_with_env_id = {k: v['predicted_value'] for k, v in policy_output.items()}
+
+                if self.policy_config.sampled_algo:
+                    root_sampled_actions_dict_with_env_id = {
+                        k: v['root_sampled_actions'] for k, v in policy_output.items()
+                    }
+
+                if not collect_with_pure_policy:
+                    distributions_dict_with_env_id = {k: v['visit_count_distributions'] for k, v in
+                                                      policy_output.items()}
+                    visit_entropy_dict_with_env_id = {k: v['visit_count_distribution_entropy'] for k, v in
+                                                      policy_output.items()}
+
+                    if self.policy_config.gumbel_algo:
+                        improved_policy_dict_with_env_id = {k: v['improved_policy_probs'] for k, v in
+                                                            policy_output.items()}
+                        completed_value_with_env_id = {k: v['roots_completed_value'] for k, v in policy_output.items()}
+
+                # Initialize dictionaries to store results
+                actions = {}
+                value_dict = {}
+                pred_value_dict = {}
+
+                if not collect_with_pure_policy:
+                    distributions_dict = {}
+                    visit_entropy_dict = {}
+
+                    if self.policy_config.sampled_algo:
+                        root_sampled_actions_dict = {}
+
+                    if self.policy_config.gumbel_algo:
+                        improved_policy_dict = {}
+                        completed_value_dict = {}
+
+                # Populate the result dictionaries
+                for env_id in ready_env_id:
+                    actions[env_id] = actions_with_env_id.pop(env_id)
+                    value_dict[env_id] = value_dict_with_env_id.pop(env_id)
+                    pred_value_dict[env_id] = pred_value_dict_with_env_id.pop(env_id)
+
+                    if not collect_with_pure_policy:
+                        distributions_dict[env_id] = distributions_dict_with_env_id.pop(env_id)
+
+                        if self.policy_config.sampled_algo:
+                            root_sampled_actions_dict[env_id] = root_sampled_actions_dict_with_env_id.pop(env_id)
+
+                        visit_entropy_dict[env_id] = visit_entropy_dict_with_env_id.pop(env_id)
+
+                        if self.policy_config.gumbel_algo:
+                            improved_policy_dict[env_id] = improved_policy_dict_with_env_id.pop(env_id)
+                            completed_value_dict[env_id] = completed_value_with_env_id.pop(env_id)
+
+                # ==============================================================
+                # Interact with the environment
+                # ==============================================================
+                timesteps = self._env.step(actions)
+
+            interaction_duration = self._timer.value / len(timesteps)
+
+            for env_id, timestep in timesteps.items():
+                with self._timer:
+                    if timestep.info.get('abnormal', False):
+                        # If there is an abnormal timestep, reset all the related variables (including this env).
+                        # Assuming there is no reset param, reset this env.
+                        self._env.reset({env_id: None})
+                        self._policy.reset([env_id])
+                        self._reset_stat(env_id)
+                        self._logger.info('Env{} returns an abnormal step, its info is {}'.format(env_id, timestep.info))
+                        continue
+                    obs, reward, done, info = timestep.obs, timestep.reward, timestep.done, timestep.info
+
+                    if collect_with_pure_policy:
+                        game_segments[env_id].store_search_stats(temp_visit_list, 0)
+                    else:
+                        if self.policy_config.sampled_algo:
+                            game_segments[env_id].store_search_stats(
+                                distributions_dict[env_id], value_dict[env_id], root_sampled_actions_dict[env_id]
+                            )
+                        elif self.policy_config.gumbel_algo:
+                            game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id],
+                                                                     improved_policy=improved_policy_dict[env_id])
+                        else:
+                            game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id])
+
+                    # append a transition tuple, including a_t, o_{t+1}, r_{t}, action_mask_{t}, to_play_{t}
+                    # in ``game_segments[env_id].init``, we have appended o_{t} in ``self.obs_segment``
+                    if self.policy_config.use_ture_chance_label_in_chance_encoder:
+                        game_segments[env_id].append(
+                            actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id],
+                            to_play_dict[env_id], chance_dict[env_id]
+                        )
+                    else:
+                        game_segments[env_id].append(
+                            actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id],
+                            to_play_dict[env_id]
+                        )
+
+                    # NOTE: the position of this code snippet is very important.
+                    # obs['action_mask'] and obs['to_play'] correspond to the next action.
+                    action_mask_dict[env_id] = to_ndarray(obs['action_mask'])
+                    to_play_dict[env_id] = to_ndarray(obs['to_play'])
+                    if self.policy_config.use_ture_chance_label_in_chance_encoder:
+                        chance_dict[env_id] = to_ndarray(obs['chance'])
+
+                    if self.policy_config.ignore_done:
+                        dones[env_id] = False
+                    else:
+                        dones[env_id] = done
+
+                    if not collect_with_pure_policy:
+                        visit_entropies_lst[env_id] += visit_entropy_dict[env_id]
+                        if self.policy_config.gumbel_algo:
+                            completed_value_lst[env_id] += np.mean(np.array(completed_value_dict[env_id]))
+
+                    eps_steps_lst[env_id] += 1
+                    if self._policy.get_attribute('cfg').type == 'unizero':
+                        # only for UniZero now
+                        self._policy.reset(env_id=env_id, current_steps=eps_steps_lst[env_id], reset_init_data=False)
+
+                    total_transitions += 1
+
+                    if self.policy_config.use_priority:
+                        pred_values_lst[env_id].append(pred_value_dict[env_id])
+                        search_values_lst[env_id].append(value_dict[env_id])
+                        if self.policy_config.gumbel_algo and not collect_with_pure_policy:
+                            improved_policy_lst[env_id].append(improved_policy_dict[env_id])
+
+                    # append the newest obs
+                    observation_window_stack[env_id].append(to_ndarray(obs['observation']))
+
+                    # ==============================================================
+                    # we will save a game segment if it is the end of the game or the next game segment is finished.
+                    # ==============================================================
+
+                    # if game segment is full, we will save the last game segment
+                    if game_segments[env_id].is_full():
+                        # pad over last segment trajectory
+                        if last_game_segments[env_id] is not None:
+                            # TODO(pu): return the one game segment
+                            self.pad_and_save_last_trajectory(
+                                env_id, last_game_segments, last_game_priorities, game_segments, dones
+                            )
+
+                        # calculate priority
+                        priorities = self._compute_priorities(env_id, pred_values_lst, search_values_lst)
+                        pred_values_lst[env_id] = []
+                        search_values_lst[env_id] = []
+                        if self.policy_config.gumbel_algo and not collect_with_pure_policy:
+                            improved_policy_lst[env_id] = []
+
+                        # the current game_segment becomes last_game_segment
+                        last_game_segments[env_id] = game_segments[env_id]
+                        last_game_priorities[env_id] = priorities
+
+                        # create new GameSegment
+                        game_segments[env_id] = GameSegment(
+                            self._env.action_space,
+                            game_segment_length=self.policy_config.game_segment_length,
+                            config=self.policy_config
+                        )
+                        game_segments[env_id].reset(observation_window_stack[env_id])
+
+                    self._env_info[env_id]['step'] += 1
+                    collected_step += 1
+
+                self._env_info[env_id]['time'] += self._timer.value + interaction_duration
+                if timestep.done:
+                    one_episode_done = True
+                    self._total_episode_count += 1
+
+                    reward = timestep.info['eval_episode_return']
+                    info = {
+                        'reward': reward,
+                        'time': self._env_info[env_id]['time'],
+                        'step': self._env_info[env_id]['step'],
+                    }
+                    if not collect_with_pure_policy:
+                        info['visit_entropy'] = visit_entropies_lst[env_id] / eps_steps_lst[env_id]
+                        if self.policy_config.gumbel_algo:
+                            info['completed_value'] = completed_value_lst[env_id] / eps_steps_lst[env_id]
+
+                    collected_episode += 1
+                    self._episode_info.append(info)
+
+                    # ==============================================================
+                    # if it is the end of the game, we will save the game segment
+                    # ==============================================================
+
+                    # NOTE: put the penultimate game segment in one episode into the trajectory_pool
+                    # pad over the second-to-last game_segment using the last game_segment
+                    if last_game_segments[env_id] is not None:
+                        self.pad_and_save_last_trajectory(
+                            env_id, last_game_segments, last_game_priorities, game_segments, dones
+                        )
+
+                    # store current segment trajectory
+                    priorities = self._compute_priorities(env_id, pred_values_lst, search_values_lst)
+
+                    # NOTE: put the last game segment in one episode into the trajectory_pool
+                    game_segments[env_id].game_segment_to_array()
+
+                    # assert len(game_segments[env_id]) == len(priorities)
+                    # NOTE: save the last game segment in one episode into the trajectory_pool if it's not null
+                    if len(game_segments[env_id].reward_segment) != 0:
+                        self.game_segment_pool.append((game_segments[env_id], priorities, dones[env_id]))
+
+                    # print(game_segments[env_id].reward_segment)
+                    # reset the finished env and init game_segments
+                    if n_episode > self._env_num:
+                        # Get current ready env obs.
+                        init_obs = self._env.ready_obs
+                        retry_waiting_time = 0.001
+                        while len(init_obs.keys()) != self._env_num:
+                            # In order to be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to
+                            # len(self._env.ready_obs), especially in tictactoe env.
+                            self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys()))
+                            self._logger.info('Before sleeping, the _env_states is {}'.format(self._env._env_states))
+                            time.sleep(retry_waiting_time)
+                            self._logger.info(
+                                '=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' + '=' * 10
+                            )
+                            self._logger.info(
+                                'After sleeping {}s, the current _env_states is {}'.format(
+                                    retry_waiting_time, self._env._env_states
+                                )
+                            )
+                            init_obs = self._env.ready_obs
+
+                        new_available_env_id = set(init_obs.keys()).difference(ready_env_id)
+                        ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode]))
+                        remain_episode -= min(len(new_available_env_id), remain_episode)
+
+                        action_mask_dict[env_id] = to_ndarray(init_obs[env_id]['action_mask'])
+                        to_play_dict[env_id] = to_ndarray(init_obs[env_id]['to_play'])
+                        if self.policy_config.use_ture_chance_label_in_chance_encoder:
+                            chance_dict[env_id] = to_ndarray(init_obs[env_id]['chance'])
+
+                        game_segments[env_id] = GameSegment(
+                            self._env.action_space,
+                            game_segment_length=self.policy_config.game_segment_length,
+                            config=self.policy_config
+                        )
+                        observation_window_stack[env_id] = deque(
+                            [init_obs[env_id]['observation'] for _ in range(self.policy_config.model.frame_stack_num)],
+                            maxlen=self.policy_config.model.frame_stack_num
+                        )
+                        game_segments[env_id].reset(observation_window_stack[env_id])
+                        last_game_segments[env_id] = None
+                        last_game_priorities[env_id] = None
+
+                    # log
+                    self_play_moves_max = max(self_play_moves_max, eps_steps_lst[env_id])
+                    if not collect_with_pure_policy:
+                        self_play_visit_entropy.append(visit_entropies_lst[env_id] / eps_steps_lst[env_id])
+                    self_play_moves += eps_steps_lst[env_id]
+                    self_play_episodes += 1
+
+                    pred_values_lst[env_id] = []
+                    search_values_lst[env_id] = []
+                    eps_steps_lst[env_id] = 0
+                    visit_entropies_lst[env_id] = 0
+
+                    # Env reset is done by env_manager automatically
+                    self._policy.reset([env_id])  # NOTE: reset the policy for the env_id. Default reset_init_data=True.
+                    self._reset_stat(env_id)
+                    ready_env_id.remove(env_id)
+
+            # if collected_episode >= n_episode:
+            # if self.collected_game_segments >= self._default_num_segments or one_episode_done:  # game_segment_length = 400
+            #     print(f'collect {self.collected_game_segments} segments now! one_episode_done: {one_episode_done}')
+            #     self.collected_game_segments = 0
+
+            # The v1 version (with this check placed inside the for loop above) probably lost samples from some environments.
+            # The v2 version below runs the check after the for loop, so the samples from every environment are returned correctly.
+            if len(self.game_segment_pool) >= self._default_num_segments or one_episode_done:  # game_segment_length = 400
+                print(f'collect {len(self.game_segment_pool)} segments now! one_episode_done: {one_episode_done}')
+                collected_enough_segments = True  # condition met: set the flag to True
+                # one_episode_done = False
+
+                # [data, meta_data]
+                return_data = [self.game_segment_pool[i][0] for i in range(len(self.game_segment_pool))], [
+                    {
+                        'priorities': self.game_segment_pool[i][1],
+                        'done': self.game_segment_pool[i][2],
+                        'unroll_plus_td_steps': self.unroll_plus_td_steps
+                    } for i in range(len(self.game_segment_pool))
+                ]
+                self.game_segment_pool.clear()
+                break
+
+            # if collected_enough_segments:  # if the condition is met, break out of the outer loop
+            #     break
+
+        collected_duration = sum([d['time'] for d in self._episode_info])
+        # reduce data when DDP is enabled
+        if self._world_size > 1:
+            collected_step = allreduce_data(collected_step, 'sum')
+            collected_episode = allreduce_data(collected_episode, 'sum')
+            collected_duration = allreduce_data(collected_duration, 'sum')
+        self._total_envstep_count += collected_step
+        self._total_episode_count += collected_episode
+        self._total_duration += collected_duration
+
+        # log
+        self._output_log(train_iter)
+        return return_data
+
+    def _output_log(self, train_iter: int) -> None:
+        """
+        Overview:
+            Log the collector's data and output the log information.
+        Arguments:
+            - train_iter (:obj:`int`): Current training iteration number for logging context.
+        """
+        if self._rank != 0:
+            return
+        if (train_iter - self._last_train_iter) >= self._collect_print_freq and len(self._episode_info) > 0:
+            self._last_train_iter = train_iter
+            episode_count = len(self._episode_info)
+            envstep_count = sum([d['step'] for d in self._episode_info])
+            duration = sum([d['time'] for d in self._episode_info])
+            episode_reward = [d['reward'] for d in self._episode_info]
+            if not self.collect_with_pure_policy:
+                visit_entropy = [d['visit_entropy'] for d in self._episode_info]
+            else:
+                visit_entropy = [0.0]
+            if self.policy_config.gumbel_algo:
+                completed_value = [d['completed_value'] for d in self._episode_info]
+            self._total_duration += duration
+            info = {
+                'episode_count': episode_count,
+                'envstep_count': envstep_count,
+                'avg_envstep_per_episode': envstep_count / episode_count,
+                'avg_envstep_per_sec': envstep_count / duration,
+                'avg_episode_per_sec': episode_count / duration,
+                'collect_time': duration,
+                'reward_mean': np.mean(episode_reward),
+                'reward_std': np.std(episode_reward),
+                'reward_max': np.max(episode_reward),
+                'reward_min': np.min(episode_reward),
+                'total_envstep_count': self._total_envstep_count,
+                'total_episode_count': self._total_episode_count,
+                'total_duration': self._total_duration,
+                'visit_entropy': np.mean(visit_entropy),
+            }
+            if self.policy_config.gumbel_algo:
+                info['completed_value'] = np.mean(completed_value)
+            self._episode_info.clear()
+            self._logger.info("collect end:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()])))
+            for k, v in info.items():
+                if k in ['each_reward']:
+                    continue
+                self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter)
+                if k in ['total_envstep_count']:
+                    continue
+                self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count)
\ No newline at end of file
diff --git a/zoo/atari/config/atari_unizero_sgement_config.py b/zoo/atari/config/atari_unizero_sgement_config.py
index 157bad577..50ff0cb82 100644
--- a/zoo/atari/config/atari_unizero_sgement_config.py
+++ b/zoo/atari/config/atari_unizero_sgement_config.py
@@ -17,24 +17,28 @@
 num_segments = 8
 # num_segments = 1
 game_segment_length=20
+# game_segment_length=100
+# game_segment_length=400
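+# NOTE: the commented values above are alternative segment lengths; keep exactly one
+# game_segment_length setting uncommented.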
 evaluator_env_num = 3
 num_simulations = 50
-max_env_step = int(3e5)
+max_env_step = int(5e5)
 reanalyze_ratio = 0.
 batch_size = 64
 num_unroll_steps = 10
 infer_context_length = 4
 
 # ====== only for debug =====
-collector_env_num = 2
-num_segments = 1
-evaluator_env_num = 2
-num_simulations = 3
-max_env_step = int(2e5)
-reanalyze_ratio = 0.
-batch_size = 2
-num_unroll_steps = 10
+# collector_env_num = 8
+# num_segments = 8
+# evaluator_env_num = 2
+# num_simulations = 3
+# max_env_step = int(2e5)
+# reanalyze_ratio = 0.
+# batch_size = 2
+# num_unroll_steps = 10
+# game_segment_length=20
+
 # ==============================================================
 # end of the most frequently changed config specified by the user
 # ==============================================================
@@ -51,8 +55,8 @@
         n_evaluator_episode=evaluator_env_num,
         manager=dict(shared_memory=False, ),
         # TODO: only for debug
-        # collect_max_episode_steps=int(50),
-        # eval_max_episode_steps=int(50),
+        # collect_max_episode_steps=int(100),
+        # eval_max_episode_steps=int(100),
     ),
     policy=dict(
        model=dict(
@@ -86,6 +90,7 @@
        num_simulations=num_simulations,
        reanalyze_ratio=reanalyze_ratio,
        num_segments=num_segments,
+       train_start_after_envsteps=2000,
        game_segment_length=game_segment_length,  # debug
        replay_buffer_size=int(1e6),
        eval_freq=int(5e3),
@@ -123,8 +128,9 @@
    seeds = [0]  # You can add more seed values here
    for seed in seeds:
        # Update exp_name to include the current seed
-       # main_config.exp_name = f'data_efficiency0829_plus_tune-uz_0912/numsegments-{num_segments}_gsl{game_segment_length}_v2/obshape96_no-augmentation/{env_id[:-14]}_stack1_unizero_upc{update_per_collect}-rr{replay_ratio}_H{num_unroll_steps}_bs{batch_size}_seed{seed}_nlayer2'
-       main_config.exp_name = f'data_efficiency0829_plus_tune-uz_debug/obshape96_use-augmentation/{env_id[:-14]}_stack1_unizero_upc{update_per_collect}-rr{replay_ratio}_H{num_unroll_steps}_bs{batch_size}_seed{seed}_nlayer2'
+       main_config.exp_name = f'data_efficiency0829_plus_tune-uz_0912/numsegments-{num_segments}_gsl{game_segment_length}_fix/obshape96_no-augmentation/{env_id[:-14]}_stack1_unizero_upc{update_per_collect}-rr{replay_ratio}_H{num_unroll_steps}_bs{batch_size}_seed{seed}_nlayer2'
+       # main_config.exp_name = f'data_efficiency0829_plus_tune-uz_0912/numsegments-{num_segments}_gsl{game_segment_length}_fix/obshape96_use-augmentation-obsw10/{env_id[:-14]}_stack1_unizero_upc{update_per_collect}-rr{replay_ratio}_H{num_unroll_steps}_bs{batch_size}_seed{seed}_nlayer2'
+       # main_config.exp_name = f'data_efficiency0829_plus_tune-uz_0912_debug/numsegments-{num_segments}_gsl{game_segment_length}_fix/obshape96_use-augmentation-obsw10/{env_id[:-14]}_stack1_unizero_upc{update_per_collect}-rr{replay_ratio}_H{num_unroll_steps}_bs{batch_size}_seed{seed}_nlayer2'
 
        from lzero.entry import train_unizero
        train_unizero([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step)
diff --git a/zoo/atari/envs/atari_lightzero_env.py b/zoo/atari/envs/atari_lightzero_env.py
index 84288feb5..281292740 100644
--- a/zoo/atari/envs/atari_lightzero_env.py
+++ b/zoo/atari/envs/atari_lightzero_env.py
@@ -132,7 +132,10 @@ def reset(self) -> dict:
         self.obs = to_ndarray(obs)
         self._eval_episode_return = 0.
+        self.timestep = 0
+
         obs = self.observe()
+
         return obs
 
     def step(self, action: int) -> BaseEnvTimestep:
@@ -148,6 +151,8 @@ def step(self, action: int) -> BaseEnvTimestep:
         self.obs = to_ndarray(obs)
         self.reward = np.array(reward).astype(np.float32)
         self._eval_episode_return += self.reward
+        self.timestep += 1
+        # print(f'self.timestep: {self.timestep}')
         observation = self.observe()
         if done:
             info['eval_episode_return'] = self._eval_episode_return
@@ -169,6 +174,7 @@ def observe(self) -> dict:
         observation = np.transpose(observation, (2, 0, 1))
 
         action_mask = np.ones(self._action_space.n, 'int8')
+        # return {'observation': observation, 'action_mask': action_mask, 'to_play': -1, 'timestep': self.timestep}
         return {'observation': observation, 'action_mask': action_mask, 'to_play': -1}
 
     @property