Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make SequenceLearner points hashable by passing the sequence to the function. #266

Closed
wants to merge 9 commits into from
50 changes: 15 additions & 35 deletions adaptive/learner/sequence_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,15 @@
from adaptive.learner.base_learner import BaseLearner


class _IgnoreFirstArgument:
"""Remove the first argument from the call signature.
class _CallFromSequence:
"""Call function with index of sequence."""

The SequenceLearner's function receives a tuple ``(index, point)``
but the original function only takes ``point``.

This is the same as `lambda x: function(x[1])`, however, that is not
pickable.
"""

def __init__(self, function):
def __init__(self, function, sequence):
self.function = function
self.sequence = sequence

def __call__(self, index_point, *args, **kwargs):
index, point = index_point
return self.function(point, *args, **kwargs)

def __getstate__(self):
return self.function

def __setstate__(self, function):
self.__init__(function)
def __call__(self, index, *args, **kwargs):
return self.function(self.sequence[index], *args, **kwargs)


class SequenceLearner(BaseLearner):
Expand All @@ -40,7 +27,7 @@ class SequenceLearner(BaseLearner):
Parameters
----------
function : callable
The function to learn. Must take a single element `sequence`.
The function to learn. Must take a single element of `sequence`.
sequence : sequence
The sequence to learn.

Expand All @@ -58,7 +45,7 @@ class SequenceLearner(BaseLearner):

def __init__(self, function, sequence):
self._original_function = function
self.function = _IgnoreFirstArgument(function)
self.function = _CallFromSequence(function, sequence)
self._to_do_indices = SortedSet({i for i, _ in enumerate(sequence)})
self._ntotal = len(sequence)
self.sequence = copy(sequence)
Expand All @@ -67,31 +54,26 @@ def __init__(self, function, sequence):

def ask(self, n, tell_pending=True):
indices = []
points = []
loss_improvements = []
for index in self._to_do_indices:
if len(points) >= n:
if len(indices) >= n:
break
point = self.sequence[index]
indices.append(index)
points.append((index, point))
loss_improvements.append(1 / self._ntotal)

if tell_pending:
for i, p in zip(indices, points):
self.tell_pending((i, p))
for index in indices:
self.tell_pending(index)

return points, loss_improvements
return indices, loss_improvements

def _get_data(self):
return self.data

def _set_data(self, data):
if data:
indices, values = zip(*data.items())
# the points aren't used by tell, so we can safely pass None
points = [(i, None) for i in indices]
self.tell_many(points, values)
self.tell_many(indices, values)

def loss(self, real=True):
if not (self._to_do_indices or self.pending_points):
Expand All @@ -105,14 +87,12 @@ def remove_unfinished(self):
self._to_do_indices.add(i)
self.pending_points = set()

def tell(self, point, value):
index, point = point
def tell(self, index, value):
self.data[index] = value
self.pending_points.discard(index)
self._to_do_indices.discard(index)

def tell_pending(self, point):
index, point = point
def tell_pending(self, index):
self.pending_points.add(index)
self._to_do_indices.discard(index)

Expand Down
31 changes: 5 additions & 26 deletions adaptive/tests/test_learners.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,19 +281,9 @@ def test_adding_existing_data_is_idempotent(learner_type, f, learner_kwargs):
M = random.randint(10, 30)
pls = zip(*learner.ask(M))
cpls = zip(*control.ask(M))
if learner_type is SequenceLearner:
# The SequenceLearner's points might not be hasable
points, values = zip(*pls)
indices, points = zip(*points)

cpoints, cvalues = zip(*cpls)
cindices, cpoints = zip(*cpoints)
assert (np.array(points) == np.array(cpoints)).all()
assert values == cvalues
assert indices == cindices
else:
# Point ordering is not defined, so compare as sets
assert set(pls) == set(cpls)
# Point ordering is not defined, so compare as sets
assert set(pls) == set(cpls)


# XXX: This *should* pass (https://github.com/python-adaptive/adaptive/issues/55)
Expand Down Expand Up @@ -324,20 +314,9 @@ def test_adding_non_chosen_data(learner_type, f, learner_kwargs):
pls = zip(*learner.ask(M))
cpls = zip(*control.ask(M))

if learner_type is SequenceLearner:
# The SequenceLearner's points might not be hasable
points, values = zip(*pls)
indices, points = zip(*points)

cpoints, cvalues = zip(*cpls)
cindices, cpoints = zip(*cpoints)
assert (np.array(points) == np.array(cpoints)).all()
assert values == cvalues
assert indices == cindices
else:
# Point ordering within a single call to 'ask'
# is not guaranteed to be the same by the API.
assert set(pls) == set(cpls)
# Point ordering within a single call to 'ask'
# is not guaranteed to be the same by the API.
assert set(pls) == set(cpls)


@run_with(Learner1D, xfail(Learner2D), xfail(LearnerND), AverageLearner)
Expand Down
28 changes: 28 additions & 0 deletions adaptive/tests/test_sequence_learner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import asyncio

import numpy as np

from adaptive import Runner, SequenceLearner
from adaptive.runner import SequentialExecutor


class FailOnce:
def __init__(self):
self.failed = False

def __call__(self, value):
if self.failed:
return value
self.failed = True
raise Exception
basnijholt marked this conversation as resolved.
Show resolved Hide resolved


def test_fail_with_sequence_of_unhashable():
# https://github.com/python-adaptive/adaptive/issues/265
seq = [dict(x=x) for x in np.linspace(-1, 1, 101)] # unhashable
basnijholt marked this conversation as resolved.
Show resolved Hide resolved
learner = SequenceLearner(FailOnce(), sequence=seq)
runner = Runner(
learner, goal=SequenceLearner.done, retries=100, executor=SequentialExecutor()
basnijholt marked this conversation as resolved.
Show resolved Hide resolved
) # with 100 retries the test will fail once in 10^31
asyncio.get_event_loop().run_until_complete(runner.task)
assert runner.status() == "finished"