Implementing Go's singleflight.Group in Python

While working on some concurrent code in Go, I found a great little package in the “extended” standard library called singleflight. Since I write a lot of asynchronous Python at work, I thought it’d be fun to port this concurrency primitive to Python.

What is singleflight?

According to the package documentation:

Package singleflight provides a duplicate function call suppression mechanism.

It exports a single type, Group, which lets multiple goroutines that call the same function concurrently share the same result, avoiding potentially unwanted duplicate computation.

Suppressing duplicate database queries with singleflight

I’m not going to go over how it works in great detail – you can read more about that in this blog post by VictoriaMetrics – since we’re here to focus on the Python side of the implementation.

The code

Let’s start by defining our Group class and its public interface:

import asyncio
from collections.abc import Callable, Coroutine


class Call[T]:
    value: T

    def __init__(self) -> None:
        self.callers = 0
        # Go uses a sync.WaitGroup instead.
        self.done = asyncio.Event()


class Group:
    def __init__(self) -> None:
        self._lock = asyncio.Lock()
        self._calls: dict[str, Call] = {}

    async def do[R, **P](
        self,
        key: str,
        f: Callable[P, Coroutine[None, None, R]],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> tuple[R, bool]:
        ...

    async def forget(self, key: str) -> None:
        ...

While the forget() method is pretty self-explanatory, the type signature of do() may be difficult to grasp at first glance. If we ignore the type annotations, though, we can see that it takes a string key, a callable f, and a set of arguments to pass to f. The return value of do() is a tuple of the result of f and a boolean flag indicating whether the result was shared with other coroutines.
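For instance, assuming a hypothetical fetch_user coroutine (the names here are illustrative, not part of the implementation), a call would look like this:

async def fetch_user(user_id: int) -> dict:
    ...  # imagine an expensive database query here


async def handler(group: Group, user_id: int) -> dict:
    # Concurrent handlers requesting the same user share a single query.
    user, shared = await group.do(f"user:{user_id}", fetch_user, user_id)
    return user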

The internals of do() look like so:

async def do[R, **P](
    self,
    key: str,
    f: Callable[P, Coroutine[None, None, R]],
    *args: P.args,
    **kwargs: P.kwargs,
) -> tuple[R, bool]:
    await self._lock.acquire()

    if key in self._calls:
        call = self._calls[key]
        call.callers += 1
        self._lock.release()

        await call.done.wait()
        return call.value, True

    call = Call[R]()
    self._calls[key] = call
    self._lock.release()

    call.value = await f(*args, **kwargs)

    async with self._lock:
        call.done.set()
        # NOTE: The identity comparison is important here because another
        # coroutine might've called forget() at this point, removing our
        # entry (and a new call might've taken its place).
        if key in self._calls and self._calls[key] is call:
            del self._calls[key]

    # 'callers' is only incremented by secondary coroutines,
    # meaning if it's 0 then the result was never shared.
    return call.value, call.callers > 0

After acquiring the lock, we check for an existing in-flight call. If one exists, we increment the number of callers, wait for the done event to be set, and return the result of the call. Note that when this branch is taken, we return from do() without ever calling f(). That’s why we accept a coroutine-returning callable (an async function) and a set of arguments instead of a coroutine object, which would otherwise never be awaited.
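To make that concrete, here’s a small sketch contrasting the two designs (the coroutine-object variant is hypothetical and shown only in comments):

async def demo(sf: Group) -> None:
    # Actual API: pass the function itself; only the winning coroutine
    # ever creates and awaits a coroutine object for compute.
    result, shared = await sf.do("foo", compute)

    # Hypothetical coroutine-object API:
    #     result, shared = await sf.do("foo", compute())
    # Every caller that merely joins an in-flight call would leave its
    # compute() coroutine object un-awaited, and Python would warn:
    # "RuntimeWarning: coroutine 'compute' was never awaited".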

In the case where we are the first caller, we initialize a new Call instance and put it in the dict. Once f() has finished, we set the done event, waking the other coroutines waiting for the same call to finish, and return the result.

Let’s now see this code in action. We’re going to simulate a complex task by sleeping for 3 seconds and returning a randomly generated number:

import asyncio
import random


async def compute() -> int:
    print("compute() called")
    await asyncio.sleep(3)
    return random.randint(1, 100)


async def main():
    sf = Group()
    async with asyncio.TaskGroup() as tg:
        tasks = [tg.create_task(sf.do("foo", compute), name=f"task-{i+1:02}") for i in range(5)]

    for task in tasks:
        result, shared = task.result()
        print(f"{task.get_name()}: {result}, shared = {shared}")


asyncio.run(main())
# Output:
# compute() called
# task-01: 65, shared = True
# task-02: 65, shared = True
# task-03: 65, shared = True
# task-04: 65, shared = True
# task-05: 65, shared = True

As you can see, the code works as expected: compute() was only called once, by one of the tasks, while the other four waited, and in the end they all got the same result. Awesome!

There’s one more thing we haven’t covered: forgetting the calls. Here’s the implementation of forget():

async def forget(self, key: str) -> None:
    async with self._lock:
        if key in self._calls:
            del self._calls[key]

That’s it.

Now, what this allows us to do is “forget” about in-flight calls. It’s easier to understand this concept in practice, so let’s expand our previous example. What if we wanted to issue a “fresh” call to compute(), even when other coroutines are already waiting for a result under the same key? Simple: just forget() the key:

sf = Group()
async with asyncio.TaskGroup() as tg:
    first = [tg.create_task(sf.do("foo", compute), name=f"task-{i+1:02}") for i in range(5)]

    # Sleep 1 second just to ensure all tasks have been spawned.
    await asyncio.sleep(1)
    await sf.forget("foo")

    # Now spawn 5 more tasks calling compute() under the same key.
    second = [tg.create_task(sf.do("foo", compute), name=f"task-{i+1:02}") for i in range(5)]

print("first batch:")
for task in first:
    result, shared = task.result()
    print(f"{task.get_name()}: {result}, shared = {shared}")

print("second batch:")
for task in second:
    result, shared = task.result()
    print(f"{task.get_name()}: {result}, shared = {shared}")

# Output:
# compute() called
# compute() called
# first batch:
# task-01: 30, shared = True
# task-02: 30, shared = True
# task-03: 30, shared = True
# task-04: 30, shared = True
# task-05: 30, shared = True
# second batch:
# task-01: 97, shared = True
# task-02: 97, shared = True
# task-03: 97, shared = True
# task-04: 97, shared = True
# task-05: 97, shared = True

This time, we see compute() being called twice, once per batch of tasks, with every coroutine in a batch sharing the same result.

NOTE: while this might be obvious to the reader, I was initially confused by the name forget. The Group does not store return values forever: once the function produces a result, do() removes the call from the dictionary. To forget, in this case, means to drop the reference to an in-progress call.
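To see that nothing is cached across completed calls, here’s a minimal sketch reusing the compute() helper from above:

async def demo() -> None:
    sf = Group()
    first, _ = await sf.do("foo", compute)
    # The first call has finished, so do() already removed it from the
    # dict; this second call runs compute() again instead of reusing a
    # cached value.
    second, _ = await sf.do("foo", compute)
    print(first, second)  # two independently computed numbers


asyncio.run(demo())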

Error handling

In Go, errors are values, and thus Do() returns 3 elements instead of 2, one of which is an error. In Python, on the other hand, we use exceptions. With the current implementation of Group, if an exception is raised in the function we passed to do(), it will propagate to the coroutine that called it first, terminating it before it gets to set the done event and leaving all other coroutines waiting forever:

async def compute() -> int:
    print("compute() called")
    await asyncio.sleep(3)
    raise RuntimeError("hello")


async def job(sf: Group) -> None:
    try:
        await sf.do("foo", compute)
    except RuntimeError as e:
        print(f"error: {e}")


async def main():
    sf = Group()
    async with asyncio.TaskGroup() as tg:
        for _ in range(5):
            tg.create_task(job(sf))

    print("all tasks finished")  # <<< unreachable because 'async with' is never exited


asyncio.run(main())

# Output:
# compute() called
# error: hello

Luckily, this is pretty easy to fix: if f() raises an exception, we catch it and re-raise it for everyone waiting for the result. Since we still want to let the caller know whether the result was shared with other coroutines, we wrap the original exception in a custom exception class. Here’s what the updated code looks like:

class Call[T]:
    # Can hold either of these attributes, but never both.
    value: T
    exc: BaseException

    def __init__(self) -> None:
        self.callers = 0
        self.done = asyncio.Event()

    def result(self) -> T:
        try:
            return self.value
        except AttributeError:
            raise self.exc from None


class DoError(Exception):
    def __init__(self, exc: BaseException, shared: bool) -> None:
        self.exc = exc
        self.shared = shared

    def __str__(self) -> str:
        return str(self.exc)


class Group:
    # ... code omitted

    async def do[R, **P](
        self,
        key: str,
        f: Callable[P, Coroutine[None, None, R]],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> tuple[R, bool]:
        # ... code omitted

        if key in self._calls:
            # ... code omitted

            await call.done.wait()
            # Return the result or wrap a potential exception.
            try:
                return call.result(), True
            except BaseException as exc:
                raise DoError(exc, True) from None

        # ... code omitted

        # Catch any error that may occur in f.
        try:
            call.value = await f(*args, **kwargs)
        except BaseException as exc:
            call.exc = exc

        # ... code omitted

        # Return the result or wrap a potential exception.
        try:
            return call.result(), call.callers > 0
        except BaseException as exc:
            raise DoError(exc, call.callers > 0) from None

And now, if we re-run our previous example, we get the following:

async def compute() -> int:
    print("compute() called")
    await asyncio.sleep(3)
    raise RuntimeError("hello")


async def job(sf: Group) -> None:
    try:
        await sf.do("foo", compute)
    except DoError as e:
        print(f"error: {e} (id = {id(e.exc)})")


async def main():
    sf = Group()
    async with asyncio.TaskGroup() as tg:
        for _ in range(5):
            tg.create_task(job(sf))

    print("all tasks finished")


asyncio.run(main())

# Output:
# compute() called
# error: hello (id = 4337637248)
# error: hello (id = 4337637248)
# error: hello (id = 4337637248)
# error: hello (id = 4337637248)
# error: hello (id = 4337637248)
# all tasks finished

Everything works as expected: compute() is still called only once, and all 5 coroutines receive the original exception.
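A caller that prefers the original exception over the wrapper can simply unwrap it; a small sketch:

async def job(sf: Group) -> int:
    try:
        result, _ = await sf.do("foo", compute)
    except DoError as e:
        raise e.exc  # surface the original RuntimeError to the caller
    return result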

Conclusion

In just 60 lines of code, we’ve replicated Go’s singleflight.Group in Python. The end result looks a lot like the original Go version, minus panic recovery and a few exception-handling details. I’ve had a lot of fun exploring this topic, and I hope you enjoyed reading the post.

You can download the full code here.