polish(nyz): polish sac comments
PaParaZz1 committed Oct 8, 2023
1 parent 203c646 commit 78a60c1
Showing 6 changed files with 308 additions and 99 deletions.
4 changes: 2 additions & 2 deletions ding/policy/base_policy.py
@@ -318,8 +318,8 @@ def collect_mode(self) -> 'Policy.collect_function': # noqa
to define immutable interfaces and restrict the usage of policy in different mode. Moreover, derived \
subclass can override the interfaces to customize its own collect mode.
Returns:
- interfaces (:obj:`Policy.collect_function`): The interfaces of collect mode of policy, it is a namedtuple \
whose values of distinct fields are different internal methods.
- interfaces (:obj:`Policy.collect_function`): The interfaces of collect mode of policy, it is a \
namedtuple whose values of distinct fields are different internal methods.
Examples:
>>> policy = Policy(cfg, model)
>>> policy_collect = policy.collect_mode
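For context on the docstring change above: collect mode is exposed as a namedtuple whose fields are internal policy methods. A minimal, self-contained sketch of that idea follows; the field names here are illustrative, not the exact fields DI-engine packs.

from collections import namedtuple

# Illustrative sketch: pack several internal methods into one immutable,
# namedtuple-style interface, as the collect_mode property described above does.
collect_function = namedtuple('collect_function', ['forward', 'process_transition', 'reset'])


def _forward(data):
    # Placeholder forward: return a dummy action for every env id in the input dict.
    return {env_id: {'action': 0} for env_id in data}


def _process_transition(obs, policy_output, timestep):
    # Placeholder transition packing.
    return {'obs': obs, 'action': policy_output['action']}


def _reset(*args, **kwargs):
    pass


policy_collect = collect_function(_forward, _process_transition, _reset)
print(policy_collect.forward({0: 'obs_0', 1: 'obs_1'}))  # {0: {'action': 0}, 1: {'action': 0}}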
37 changes: 31 additions & 6 deletions ding/policy/common_utils.py
@@ -1,4 +1,4 @@
from typing import List, Any, Dict
from typing import List, Any, Dict, Callable
import torch
import numpy as np
import treetensor.torch as ttorch
@@ -56,7 +56,7 @@ def default_preprocess_learn(
else:
data['weight'] = data.get('weight', None)
if use_nstep:
# Reward reshaping for n-step
# reward reshaping for n-step
reward = data['reward']
if len(reward.shape) == 1:
reward = reward.unsqueeze(1)
@@ -69,10 +69,22 @@
return data


def single_env_forward_wrapper(forward_fn):
def single_env_forward_wrapper(forward_fn: Callable) -> Callable:
"""
Overview:
Wrap policy to support gym-style interaction between policy and environment.
Wrap policy to support gym-style interaction between policy and single environment.
Arguments:
- forward_fn (:obj:`Callable`): The original forward function of policy.
Returns:
- wrapped_forward_fn (:obj:`Callable`): The wrapped forward function of policy.
Examples:
>>> env = gym.make('CartPole-v0')
>>> policy = DQNPolicy(...)
>>> forward_fn = single_env_forward_wrapper(policy.eval_mode.forward)
>>> obs = env.reset()
>>> action = forward_fn(obs)
>>> next_obs, rew, done, info = env.step(action)
"""

def _forward(obs):
@@ -84,10 +96,23 @@ def _forward(obs):
return _forward


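The body of ``single_env_forward_wrapper`` is collapsed in this view. Below is a minimal sketch of the pattern its docstring describes, under the assumption that the wrapped policy forward takes a dict keyed by env id and returns per-env dicts with an 'action' field (consistent with the Examples above); it is not the repository's exact implementation.

from typing import Callable


def single_env_forward_wrapper_sketch(forward_fn: Callable) -> Callable:
    # Adapt a batch-style policy forward to a gym-style, single-environment call.

    def _forward(obs):
        obs = {0: obs}  # pack the single observation under env id 0
        action = forward_fn(obs)[0]['action']  # unpack the action for env id 0
        return action

    return _forward


# Toy usage with a stand-in policy forward that returns a zero action per env id.
dummy_policy_forward = lambda obs: {env_id: {'action': 0} for env_id in obs}
step_fn = single_env_forward_wrapper_sketch(dummy_policy_forward)
print(step_fn([0.0, 0.1, 0.2, 0.3]))  # 0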
def single_env_forward_wrapper_ttorch(forward_fn, cuda=True):
def single_env_forward_wrapper_ttorch(forward_fn: Callable, cuda: bool = True) -> Callable:
"""
Overview:
Wrap policy to support gym-style interaction between policy and environment for treetensor (ttorch) data.
Wrap policy to support gym-style interaction between policy and single environment for treetensor (ttorch) data.
Arguments:
- forward_fn (:obj:`Callable`): The original forward function of policy.
- cuda (:obj:`bool`): Whether to use cuda in policy. If True, this function will move the input data to cuda.
Returns:
- wrapped_forward_fn (:obj:`Callable`): The wrapped forward function of policy.
Examples:
>>> env = gym.make('CartPole-v0')
>>> policy = PPOFPolicy(...)
>>> forward_fn = single_env_forward_wrapper_ttorch(policy.eval)
>>> obs = env.reset()
>>> action = forward_fn(obs)
>>> next_obs, rew, done, info = env.step(action)
"""

def _forward(obs):
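The treetensor variant documented above follows the same pattern but also converts the observation to tensor form, optionally moves it to cuda, and unpacks the action back to the CPU. A rough, plain-torch approximation of that idea is sketched below; the real function operates on treetensor (ttorch) data, whose exact handling is collapsed in this view, and the 'action' key access is an assumption.

from typing import Callable
import torch


def single_env_forward_wrapper_cuda_sketch(forward_fn: Callable, cuda: bool = True) -> Callable:

    def _forward(obs):
        obs = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0)  # add a batch dim of 1
        if cuda and torch.cuda.is_available():
            obs = obs.cuda()  # move the input data to cuda, as the docstring describes
        action = forward_fn(obs)['action']
        return action.squeeze(0).cpu().numpy()  # strip the batch dim, return a numpy action

    return _forward


# Toy usage with a stand-in policy forward that returns a zero action tensor.
dummy_forward = lambda obs: {'action': torch.zeros(obs.shape[0], 1)}
step_fn = single_env_forward_wrapper_cuda_sketch(dummy_forward, cuda=False)
print(step_fn([0.0, 0.1, 0.2, 0.3]))  # [0.]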
15 changes: 7 additions & 8 deletions ding/policy/ddpg.py
@@ -64,8 +64,7 @@ class DDPGPolicy(Policy):
type='ddpg',
# (bool) Whether to use cuda in policy.
cuda=False,
# (bool) Whether learning policy is the same as collecting data policy(on-policy).
# Default False in DDPG.
# (bool) Whether learning policy is the same as collecting data policy(on-policy). Default False in DDPG.
on_policy=False,
# (bool) Whether to enable priority experience sample.
priority=False,
@@ -84,7 +83,7 @@ class DDPGPolicy(Policy):
multi_agent=False,
# learn_mode config
learn=dict(
# How many updates(iterations) to train after collector's one collection.
# (int) How many updates(iterations) to train after collector's one collection.
# Bigger "update_per_collect" means bigger off-policy.
# collect data -> update policy-> collect data -> ...
update_per_collect=1,
@@ -150,7 +149,7 @@ def default_model(self) -> Tuple[str, List[str]]:
return 'continuous_qac', ['ding.model.template.qac']

def _init_learn(self) -> None:
r"""
"""
Overview:
Learn mode init method. Called by ``self.__init__``.
Init actor and critic optimizers, algorithm config, main and target models.
@@ -202,7 +201,7 @@ def _init_learn(self) -> None:
self._forward_learn_cnt = 0 # count iterations

def _forward_learn(self, data: dict) -> Dict[str, Any]:
r"""
"""
Overview:
Forward and backward function of learn mode.
Arguments:
@@ -343,7 +342,7 @@ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
self._optimizer_critic.load_state_dict(state_dict['optimizer_critic'])

def _init_collect(self) -> None:
r"""
"""
Overview:
Collect mode init method. Called by ``self.__init__``.
Init traj and unroll length, collect model.
@@ -365,7 +364,7 @@ def _init_collect(self) -> None:
self._collect_model.reset()

def _forward_collect(self, data: dict, **kwargs) -> dict:
r"""
"""
Overview:
Forward function of collect mode.
Arguments:
@@ -431,7 +430,7 @@ def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str,
return get_train_sample(transitions, self._unroll_len)

def _init_eval(self) -> None:
r"""
"""
Overview:
Evaluate mode init method. Called by ``self.__init__``.
Init eval model. Unlike learn and collect model, eval model does not need noise.
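The ``update_per_collect`` comment in the DDPG config hunk above describes the alternating collect/update cycle. A minimal, self-contained sketch of that cycle follows; ``collect_data`` and ``train_step`` are placeholders, not DI-engine APIs.

# Hypothetical outline of "collect data -> update policy -> collect data -> ...".
# A larger update_per_collect means more gradient steps per collected batch,
# i.e. the training data becomes increasingly off-policy.
update_per_collect = 4


def collect_data(policy_state):
    # Placeholder collector: return a batch of dummy transitions.
    return [{'obs': None, 'action': None, 'reward': 0.0} for _ in range(8)]


def train_step(policy_state, batch):
    # Placeholder learner iteration: pretend to update the policy.
    return policy_state + 1


policy_state = 0
replay_buffer = []
for iteration in range(3):
    replay_buffer.extend(collect_data(policy_state))
    for _ in range(update_per_collect):
        policy_state = train_step(policy_state, replay_buffer)

print(policy_state)  # 12 updates after 3 collections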
41 changes: 36 additions & 5 deletions ding/policy/policy_factory.py
@@ -8,17 +8,32 @@


class PolicyFactory:
r"""
"""
Overview:
Pure random policy. Only used for initial sample collecting if `cfg.policy.random_collect_size` > 0.
Policy factory class, used to generate different policies for general purposes, such as a random action \
policy, which is used for initial sample collection for better exploration when ``random_collect_size`` > 0.
Interfaces:
``get_random_policy``
"""

@staticmethod
def get_random_policy(
policy: 'BasePolicy', # noqa
policy: 'Policy.collect_mode', # noqa
action_space: 'gym.spaces.Space' = None, # noqa
forward_fn: Callable = None,
) -> None:
) -> 'Policy.collect_mode': # noqa
"""
Overview:
According to the given action space, define the forward function of the random policy, then pack it with \
other interfaces of the given policy, and return the final collect mode interfaces of policy.
Arguments:
- policy (:obj:`Policy.collect_mode`): The collect mode interfaces of the policy.
- action_space (:obj:`gym.spaces.Space`): The action space of the environment, gym-style.
- forward_fn (:obj:`Callable`): If the action space is too complex, you can define your own forward function \
and pass it to this function; note that you should set ``action_space`` to ``None`` in this case.
Returns:
- random_policy (:obj:`Policy.collect_mode`): The collect mode interfaces of the random policy.
"""
assert not (action_space is None and forward_fn is None)
random_collect_function = namedtuple(
'random_collect_function', [
@@ -69,7 +84,23 @@ def reset(*args, **kwargs) -> None:
)


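As a rough illustration of what ``PolicyFactory.get_random_policy`` builds from an action space, here is a hedged sketch of a random forward function. The ``{env_id: {'action': ...}}`` output format is an assumption based on the collect-mode convention; the real method also re-packs the policy's other interfaces into the collect-mode namedtuple.

import gym


def make_random_forward(action_space: gym.spaces.Space):
    # Return a forward function that samples one random action per env id.

    def _forward(data, *args, **kwargs):
        return {env_id: {'action': action_space.sample()} for env_id in data}

    return _forward


forward_fn = make_random_forward(gym.spaces.Discrete(4))
print(forward_fn({0: None, 1: None}))  # e.g. {0: {'action': 2}, 1: {'action': 0}}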
def get_random_policy(cfg: EasyDict, policy: 'Policy.collect_mode', env: 'BaseEnvManager'): # noqa
def get_random_policy(
cfg: EasyDict,
policy: 'Policy.collect_mode', # noqa
env: 'BaseEnvManager' # noqa
) -> 'Policy.collect_mode': # noqa
"""
Overview:
The entry function to get the corresponding random policy. If a policy needs special data items in a \
transition, this function returns the policy itself; otherwise, it uses ``PolicyFactory`` to return a general random policy.
Arguments:
- cfg (:obj:`EasyDict`): The EasyDict-type dict configuration.
- policy (:obj:`Policy.collect_mode`): The collect mode interfaces of the policy.
- env (:obj:`BaseEnvManager`): The env manager instance, which is used to get the action space for random \
action generation.
Returns:
- random_policy (:obj:`Policy.collect_mode`): The collect mode interfaces of the random policy.
"""
if cfg.policy.get('transition_with_policy_data', False):
return policy
else:
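For the ``forward_fn`` escape hatch mentioned in the ``PolicyFactory.get_random_policy`` docstring (action spaces too complex for ``action_space.sample()``), here is a hedged sketch of a user-defined random forward. The nested action keys are purely illustrative, and the commented usage line assumes a DI-engine policy instance named ``policy``.

import numpy as np


def my_random_forward(data, *args, **kwargs):
    # Structured random action per env id, for an action space that a single
    # gym.spaces.Space.sample() call cannot describe directly.
    return {
        env_id: {'action': {'move': np.random.randint(0, 4), 'jump': bool(np.random.rand() > 0.5)}}
        for env_id in data
    }


# Usage sketch: pass the custom forward and leave action_space as None,
# as the docstring above instructs.
# random_policy = PolicyFactory.get_random_policy(policy.collect_mode, forward_fn=my_random_forward)
print(my_random_forward({0: None}))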