diff --git a/src/executorlib/standalone/batched.py b/src/executorlib/standalone/batched.py index 6adee8590..3c44c0c75 100644 --- a/src/executorlib/standalone/batched.py +++ b/src/executorlib/standalone/batched.py @@ -1,26 +1,28 @@ from concurrent.futures import Future -def batched_futures(lst: list[Future], skip_lst: list[list], n: int) -> list[list]: +def batched_futures( + lst: list[Future], nested_skip_lst: list[Future[list]], n: int +) -> list[list]: """ Batch n completed future objects. If the number of completed futures is smaller than n and the end of the batch is - not reached yet, then an empty list is returned. If n future objects are done, which are not included in the skip_lst + not reached yet, then an empty list is returned. If n future objects are done, which are not included in the skip_set then they are returned as batch. Args: lst (list): list of all future objects - skip_lst (list): list of previous batches of future objects + nested_skip_lst (list): nest list of individual results already assigned to previous batches n (int): batch size Returns: list: results of the batched futures """ - skipped_ids = {id(item) for items in skip_lst for item in items} + skip_set = {id(item) for f in nested_skip_lst for item in f.result()} done_lst = [] - n_expected = min(n, len(lst) - len(skipped_ids)) + n_expected = min(n, len(lst) - len(skip_set)) for v in lst: - if v.done() and id(v.result()) not in skipped_ids: + if v.done() and id(v.result()) not in skip_set: done_lst.append(v.result()) if len(done_lst) == n_expected: return done_lst diff --git a/src/executorlib/task_scheduler/interactive/dependency.py b/src/executorlib/task_scheduler/interactive/dependency.py index fe24c3787..349b3f1c7 100644 --- a/src/executorlib/task_scheduler/interactive/dependency.py +++ b/src/executorlib/task_scheduler/interactive/dependency.py @@ -346,7 +346,7 @@ def _update_waiting_task( done_lst = batched_futures( lst=task_wait_dict["kwargs"]["lst"], n=task_wait_dict["kwargs"]["n"], - skip_lst=[f.result() for f in task_wait_dict["kwargs"]["skip_lst"]], + nested_skip_lst=task_wait_dict["kwargs"]["skip_lst"], ) if len(done_lst) == 0: wait_tmp_lst.append(task_wait_dict) diff --git a/tests/unit/standalone/test_batched.py b/tests/unit/standalone/test_batched.py index 9b811d26a..31e3d578c 100644 --- a/tests/unit/standalone/test_batched.py +++ b/tests/unit/standalone/test_batched.py @@ -10,14 +10,18 @@ def test_batched_futures(self): f = Future() f.set_result(i) lst.append(f) - self.assertEqual(batched_futures(lst=lst, n=3, skip_lst=[]), [0, 1, 2]) - self.assertEqual(batched_futures(lst=lst, skip_lst=[[0, 1, 2]], n=3), [3, 4, 5]) - self.assertEqual(batched_futures(lst=lst, skip_lst=[[0, 1, 2], [3, 4, 5]], n=3), [6, 7, 8]) - self.assertEqual(batched_futures(lst=lst, skip_lst=[[0, 1, 2], [3, 4, 5], [6, 7, 8]], n=3), [9]) + batched_lst = [Future(), Future(), Future()] + batched_lst[0].set_result([0, 1, 2]) + batched_lst[1].set_result([3, 4, 5]) + batched_lst[2].set_result([6, 7, 8]) + self.assertEqual(batched_futures(lst=lst, n=3, nested_skip_lst=set()), [0, 1, 2]) + self.assertEqual(batched_futures(lst=lst, nested_skip_lst=batched_lst[:1], n=3), [3, 4, 5]) + self.assertEqual(batched_futures(lst=lst, nested_skip_lst=batched_lst[:2], n=3), [6, 7, 8]) + self.assertEqual(batched_futures(lst=lst, nested_skip_lst=batched_lst, n=3), [9]) def test_batched_futures_not_finished(self): lst = [] for _ in list(range(10)): f = Future() lst.append(f) - self.assertEqual(batched_futures(lst=lst, n=3, skip_lst=[]), []) + self.assertEqual(batched_futures(lst=lst, n=3, nested_skip_lst=set()), [])