From 2e83b4c5a7541232191fc8dab4f38c013a88dc29 Mon Sep 17 00:00:00 2001
From: Edward Hope-Morley
Date: Tue, 7 May 2024 09:03:32 +0100
Subject: [PATCH] Optimise data transfer between principal and worker procs

SearchResultMinimal is used to transfer results between the worker
processes and the main collector. As a result it must be as small as
possible in order to keep transfers fast and the memory footprint low.
Several unnecessary variables were being stored and duplicated in this
object; they have now been removed, making the transfer faster and
using less memory. Transfers are also batched so as to reduce
interruptions to searches.

Also removes unnecessary use of multiprocessing.Queue for the
single-thread use case.
---
 searchkit/search.py       | 392 +++++++++++++++++++++++++-------------
 tests/unit/test_search.py |  20 +-
 2 files changed, 276 insertions(+), 136 deletions(-)

diff --git a/searchkit/search.py b/searchkit/search.py
index e480b5a..97273c3 100755
--- a/searchkit/search.py
+++ b/searchkit/search.py
@@ -1,5 +1,6 @@
 import abc
 import concurrent.futures
+import copy
 import glob
 import gzip
 import multiprocessing
@@ -20,6 +21,7 @@
 RESULTS_QUEUE_TIMEOUT = 60
 MAX_QUEUE_RETRIES = 10
 RS_LOCK = multiprocessing.Lock()
+NUM_BUFFERED_RESULTS = 100


 def rs_locked(f):
@@ -114,6 +116,9 @@ def run(self, line):

         return ret

+    def __repr__(self):
+        return ', '.join([str(p) for p in self.patterns])
+

 class SequenceSearchDef(SearchDefBase):

@@ -229,39 +234,73 @@ def remove(self, sid):
         del self.data[sid]


-class ResultStoreBase(UserDict):
+class ResultStoreBase(UserList):
+    """
+    This class is used to de-duplicate values stored in search results so
+    that only a reference needs to be saved in each result for later lookup.
+    """

     def __init__(self):
-        self.head = 0
-        self.index = {}
-        self.meta = {}
-        self.data = {}
+        self.counters = {}
+        self.value_store = self.data = []
+        self.tag_store = []
+        self.sequence_id_store = []

     def __getitem__(self, result_id):
-        return self.data.get(result_id)
+        if result_id >= len(self.value_store):
+            return None

-    def increment_head(self):
-        """ Incrementing differs for proxied vs. raw types so we leave this to
-        implementations to figure out. """
-        self.head += 1
+        return self.value_store[result_id]

     @property
-    def num_deduped(self):
-        counters = self.meta.values()
+    def parts_deduped(self):
+        counters = self.counters.values()
         return sum(counters) - len(counters)

-    def add(self, value):
-        _id = self.index.get(value)
-        if _id is not None:
-            self.meta[_id] += 1
-            return _id
+    @property
+    def parts_non_deduped(self):
+        return len(self.value_store)
+
+    def _get_value_index(self, value, store):
+        """
+        If the value is not None and not already saved, save it.
+
+        Returns the position of the value in the store or -1 if it is None.
+        """
+        if value is None:
+            return -1
+
+        if value in store:
+            return store.index(value)
+
+        store.append(value)
+        return len(store) - 1

-        _id = self.head
-        self.data[_id] = value
-        self.meta[_id] = 1
-        self.index[value] = _id
-        self.increment_head()
-        return _id
+    def add(self, tag, sequence_id, value):
+        """
+        Ensure the given values are saved in the store and return their
+        position.
+
+        Returns a tuple of references to the position in the store of each
+        value. A ref value of -1 indicates that the value does not exist and is
+        not stored.
+
+        @param tag: optional search tag
+        @param sequence_id: optional sequence search id
+        @param value: search result value
+        """
+        value_idx = self._get_value_index(value, self.value_store)
+        if value_idx >= 0:
+            # increment global counter
+            if value_idx not in self.counters:
+                self.counters[value_idx] = 1
+            else:
+                self.counters[value_idx] += 1
+
+        tag_idx = self._get_value_index(tag, self.tag_store)
+        sequence_id_idx = self._get_value_index(sequence_id,
+                                                self.sequence_id_store)
+        return tag_idx, sequence_id_idx, value_idx


 class ResultStoreSimple(ResultStoreBase):
@@ -272,30 +311,28 @@ class ResultStoreParallel(ResultStoreBase):
     """ Store for use when sharing between processes is required. """

     def __init__(self, mgr):
-        self._head = mgr.Value('i', 0)
-        self.meta = mgr.dict()
-        self.index = mgr.dict()
-        self.data = mgr.dict()
+        self.counters = mgr.dict()
+        self.value_store = self.data = mgr.list()
+        self.tag_store = mgr.list()
+        self.sequence_id_store = mgr.list()

     @rs_locked
     def __getitem__(self, result_id):
         return super().__getitem__(result_id)

-    @property
-    def head(self):
-        return self._head.value
-
-    def increment_head(self):
-        self._head.value = self.head + 1
+    @rs_locked
+    def add(self, *args, **kwargs):
+        return super().add(*args, **kwargs)

+    @property
     @rs_locked
-    def add(self, value):
-        return super().add(value)
+    def parts_deduped(self):
+        return super().parts_deduped

     @property
     @rs_locked
-    def num_deduped(self):
-        return super().num_deduped
+    def parts_non_deduped(self):
+        return super().parts_non_deduped

     @rs_locked
     def unproxy_results(self):
@@ -303,11 +340,11 @@ def unproxy_results(self):
         Converts internal stores to unproxied types so they can be accessed
         once their manager is gone.
         """
-        log.debug("unproxying results store (data=%s)", len(self.data))
-        self._head = self._head.value
-        self.data = self.data.copy()
-        self.meta = self.meta.copy()
-        self.index = self.index.copy()
+        log.debug("unproxying results store (data=%s)", len(self.value_store))
+        self.value_store = self.data = copy.deepcopy(self.data)
+        self.tag_store = copy.deepcopy(self.tag_store)
+        self.sequence_id_store = copy.deepcopy(self.sequence_id_store)
+        self.counters = self.counters.copy()


 class ResultFieldInfo(UserDict):
@@ -345,65 +382,108 @@ def index_to_name(self, index):

 class SearchResultBase(UserList):

+    META_OFFSET_TAG = 0
+    META_OFFSET_SEQ_ID = 1

-    def get(self, field):
-        """
-        Retrieve result part value by index or name.
+    PART_OFFSET_IDX = 0
+    PART_OFFSET_VALUE = 1
+    PART_OFFSET_FIELD = 2

-        @param field: integer index of string field name.
-        """
+    @abc.abstractmethod
+    def __init__(self):
+        self.results_store = None
+        self.linenumber = None
+        self.section_id = None
+
+    def _get_store_id(self, field):
         for part in self.data:
             store_id = None
-            if isinstance(field, str):
-                if part['name'] == field:
-                    store_id = part['store_id']
-            elif part['idx'] == field:
-                store_id = part['store_id']
+            # Entry has format: (<idx>, <store_id>[, <field name>])
+            if len(part) > self.PART_OFFSET_FIELD and isinstance(field, str):
+                if part[self.PART_OFFSET_FIELD] != field:
+                    continue
+            elif part[self.PART_OFFSET_IDX] != field:
+                continue
+
+            store_id = part[self.PART_OFFSET_VALUE]
             if store_id is not None:
-                return self.results_store.get(store_id)
+                return store_id

-    def __getattr__(self, name):
-        if name != 'field_info':
-            if self.field_info and name in self.field_info:
-                return self.get(name)
+    def get(self, field):
+        """
+        Retrieve result part value by index or name.

-        raise AttributeError("'{}' object has no attribute '{}'".
-                             format(self.__class__.__name__, name))
+        @param field: integer index of string field name.
+        """
+        store_id = self._get_store_id(field)
+        if store_id is not None:
+            return self.results_store[store_id]

     def __iter__(self):
         """ Only return part values when iterating over this object. """
         for part in self.data:
-            yield self.results_store.get(part['store_id'])
+            yield self.results_store[part[self.PART_OFFSET_VALUE]]

     def __repr__(self):
-        r_list = ["{}='{}'".format(rp['idx'],
-                                   self.results_store.get(rp['store_id']))
+        r_list = ["{}='{}'".
+                  format(rp[self.PART_OFFSET_IDX],
+                         self.results_store[rp[self.PART_OFFSET_VALUE]])
                   for rp in self.data]
         return ("ln:{} {} (section={})".
-                format(self.linenumber, ", ".join(r_list),
-                       self.section_id))
+                format(self.linenumber, ", ".join(r_list), self.section_id))


 class SearchResultMinimal(SearchResultBase):

-    def __init__(self, result_id, data, linenumber, source_id, tag,
-                 sequence_id, sequence_section_id, field_info):
+    def __init__(self, data, metadata, linenumber, source_id,
+                 sequence_section_id, field_info):
         """
         This is a minimised representation of a SearchResult object so as
         to reduce its size as much as possible before putting on the
         results queue.
+
+        IMPORTANT: this class must contain as few attributes as possible and
+        their values must be as small as possible. For large values that need
+        to be shared, we de-duplicate using ResultStoreBase implementations.
+
+        Do not store references to shared objects.
         """
-        self.id = result_id
-        self.data = data
+        self.data = data[:]
+        self.metadata = metadata[:]
         self.linenumber = linenumber
         self.source_id = source_id
-        self.tag = tag
-        self.sequence_id = sequence_id
         self.section_id = sequence_section_id
-        self.field_info = field_info
+        if field_info:
+            self.field_names = list(field_info)
+        else:
+            self.field_names = None
+
+        self.results_store = None

+    def __getattr__(self, name):
+        if name != 'field_names':
+            if self.field_names and name in self.field_names:
+                return self.get(name)
+
+        raise AttributeError("'{}' object has no attribute '{}'".
+                             format(self.__class__.__name__, name))
+
+    @property
+    def tag(self):
+        idx = self.metadata[self.META_OFFSET_TAG]
+        if idx < 0:
+            return
+
+        return self.results_store.tag_store[idx]
+
+    @property
+    def sequence_id(self):
+        idx = self.metadata[self.META_OFFSET_SEQ_ID]
+        if idx < 0:
+            return
+
+        return self.results_store.sequence_id_store[idx]
+
     def register_results_store(self, store):
         """ Register a ResultsStore with this result.
        This is used to re-register
@@ -445,7 +525,7 @@ def __init__(self, linenumber, source_id, result, search_def,
         self.field_info = search_def.field_info

         if not search_def.store_result_contents:
-            log.debug("store_contents is False - skipping save")
+            log.debug("store_contents is False - skipping save value")
             return

         self.store_result(result)
@@ -464,34 +544,34 @@ def store_result(self, result):
                       "(tag=%s)", self.tag)
             self._save_part(0, result.group(0))

-    @cached_property
-    def id(self):
-        """ Unique Result ID """
-        id_string = "{}-{}-{}".format(uuid.uuid4(), self.source_id,
-                                      self.linenumber)
-        if self.sequence_id:
-            id_string = "{}-{}-{}".format(id_string,
-                                          self.sequence_id,
-                                          self.section_id)
-        return id_string
-
-    def _save_part(self, part_index, value):
+    @property
+    def metadata(self):
+        tag_id, seq_id, _ = self.results_store.add(self.tag, self.sequence_id,
+                                                   value=None)
+        return (tag_id, seq_id)
+
+    def _save_part(self, part_index, value=None):
         name = None
-        if value is not None and self.field_info:
-            name = self.field_info.index_to_name(part_index - 1)
-            value = self.field_info.ensure_type(name, value)
+        if value is not None:
+            if self.field_info:
+                name = self.field_info.index_to_name(part_index - 1)
+                value = self.field_info.ensure_type(name, value)
+
+        _, _, store_id = self.results_store.add(self.tag, self.sequence_id,
+                                                value)
+        if name is None:
+            entry = (part_index, store_id)
+        else:
+            entry = (part_index, store_id, name)

-        store_id = self.results_store.add(value)
-        self.data.append({'idx': part_index, 'store_id': store_id,
-                          'name': name})
+        self.data.append(entry)

     @cached_property
     def export(self):
         """ Export the smallest possible representation of this object. """
-        return SearchResultMinimal(self.id, self.data, self.linenumber,
-                                   self.source_id, self.tag,
-                                   self.sequence_id, self.section_id,
-                                   self.field_info)
+        return SearchResultMinimal(self.data, self.metadata,
+                                   self.linenumber, self.source_id,
+                                   self.section_id, self.field_info)


 class SearchResultsCollection(UserDict):
@@ -504,19 +584,19 @@ def __init__(self, search_catalog, results_store):

     @property
     def data(self):
         results = {}
-        for path, ids in self._results_by_path.items():
-            results[path] = [self._results_by_id[id] for id in ids]
+        for path, _results in self._results_by_path.items():
+            results[path] = _results
         return results

     @property
     def all(self):
-        for r in self._results_by_id.values():
-            yield r
+        for results in self._results_by_path.values():
+            for r in results:
+                yield r

     def reset(self):
         self._results_by_path = {}
-        self._results_by_id = {}

     @property
     def files(self):
@@ -526,16 +606,14 @@ def add(self, result):
         result.register_results_store(self.results_store)
         # resolve
         path = self.search_catalog.source_id_to_path(result.source_id)
-        self._results_by_id[result.id] = result
         if path not in self._results_by_path:
-            self._results_by_path[path] = [result.id]
+            self._results_by_path[path] = [result]
         else:
-            self._results_by_path[path].append(result.id)
+            self._results_by_path[path].append(result)

     def find_by_path(self, path):
         """ Return results for a given path. """
-        results = self._results_by_path.get(path, [])
-        return [self._results_by_id[id] for id in results]
+        return self._results_by_path.get(path, [])

     def find_by_tag(self, tag, path=None):
         """ Return results matched by tag.
@@ -575,7 +653,7 @@ def _get_all_sequence_results(self, path=None):
             if result.sequence_id is None:
                 continue

-            sequences.append(result.id)
+            sequences.append(result)

         return sequences

@@ -606,8 +684,7 @@ def find_sequence_sections(self, sequence_obj, path=None):
         @param path: optionally filter results for a given path.
         """
         _results = {}
-        for r in self._get_all_sequence_results(path=path):
-            result = self._results_by_id[r]
+        for result in self._get_all_sequence_results(path=path):
             s_id = result.sequence_id
             if s_id != sequence_obj.id:
                 continue
@@ -768,14 +845,14 @@ def source_id_to_path(self, s_id):
             log.error('\n'.join(list(self._source_ids.keys())))

     def _get_source_id(self, path):
-        for source_id, _path in self._source_ids.items():
-            if _path == path:
-                return source_id
+        if not self._source_ids:
+            source_id = 0
+        else:
+            for source_id, _path in self._source_ids.items():
+                if _path == path:
+                    return source_id

-        source_id = str(uuid.uuid4())
-        while source_id in self._source_ids:
-            log.error("source id %s already exists - trying again", source_id)
-            source_id = str(uuid.uuid4())
+            source_id = max(list(self._source_ids)) + 1

         log.debug("path=%s source_id=%s", path, source_id)
         self._source_ids[source_id] = path
@@ -788,19 +865,52 @@ def __iter__(self):
         for entry in self._entries.values():
             yield entry

+    def __repr__(self):
+        info = ""
+        for path, searches in self._entries.items():
+            info += "\n{}:\n ".format(path)
+            entries = []
+            for key, val in searches.items():
+                entries.append("{}={}".format(key, val))
+
+            info += '\n '.join(entries)
+
+        return info
+

 class SearchTask(object):

-    def __init__(self, info, constraints_manager, results_queue,
-                 results_store, decode_errors=None):
+    def __init__(self, info, constraints_manager, results_store,
+                 results_queue=None, results_collection=None,
+                 decode_errors=None):
+        """
+        Run search task on file.
+
+        @param info: dictionary containing information about what we are
+                     searching incl. path, searchdefs etc.
+        @param constraints_manager: SearchConstraintsManager object
+        @param results_store: ResultStoreSimple or ResultStoreParallel object
+        @param results_queue: optional multiprocessing.Queue. This must be
+                              provided if task is running in a child process.
+        @param results_collection: optional SearchResultsCollection. This must
+                                   be provided if task is not running in
+                                   parallel mode.
+        @param decode_errors: unicode decode error handling.
+ """ + if results_queue is not None and results_collection is not None: + raise Exception("only one of results_queue and results_collection " + "can be used with a SearchTask.") + self.info = info self.stats = SearchTaskStats() self.constraints_manager = constraints_manager self.results_queue = results_queue + self.results_collection = results_collection self.results_store = results_store self.decode_kwargs = {} if decode_errors: self.decode_kwargs['errors'] = decode_errors + self.buffered_results = [] @cached_property def id(self): @@ -822,13 +932,17 @@ def search_defs(self): def put_result(self, result): self.stats['results'] += 1 + if self.results_collection is not None: + self.results_collection.add(result) + return + max_tries = MAX_QUEUE_RETRIES while max_tries > 0: try: if max_tries == MAX_QUEUE_RETRIES: - self.results_queue.put_nowait(result.export) + self.results_queue.put_nowait(result) else: - self.results_queue.put(result.export, + self.results_queue.put(result, timeout=RESULTS_QUEUE_TIMEOUT) break @@ -849,6 +963,13 @@ def put_result(self, result): log.error("exceeded max number of retries (%s) to put results " "data on the queue", MAX_QUEUE_RETRIES) + def _flush_results_buffer(self): + # log.debug("flushing results buffer (%s)", len(self.buffered_results)) + for result in self.buffered_results: + self.put_result(result) + + self.buffered_results = [] + def _simple_search(self, search_def, line, ln): """ Perform a simple search on line. @@ -860,8 +981,11 @@ def _simple_search(self, search_def, line, ln): if not ret: return - self.put_result(SearchResult(ln, self.info['source_id'], ret, - search_def, self.results_store)) + result = SearchResult(ln, self.info['source_id'], ret, search_def, + self.results_store) + self.buffered_results.append(result.export) + if len(self.buffered_results) >= NUM_BUFFERED_RESULTS: + self._flush_results_buffer() def _sequence_search(self, seq_def, line, ln, sequence_results): """ Perform a sequence search on line. @@ -968,7 +1092,9 @@ def _process_sequence_results(self, sequence_results, current_ln): if r.section_id in filter_section_id[seq_id]: continue - self.put_result(r) + self.buffered_results.append(r.export) + if len(self.buffered_results) >= NUM_BUFFERED_RESULTS: + self._flush_results_buffer() def _run_search(self, fd): """ @@ -1007,6 +1133,9 @@ def _run_search(self, fd): self._simple_search(s_def, line, ln) self._process_sequence_results(sequence_results, ln) + # allow garbage collect + sequence_results = {} + log.debug("completed search of %s lines", self.stats['lines_searched']) if self.search_defs_conditional: msg = "constraints stats {}:".format(fd.name) @@ -1043,9 +1172,11 @@ def execute(self): fd.read(1) fd.seek(0) stats = self._run_search(fd) + self._flush_results_buffer() except OSError: with open(path, 'rb') as fd: stats = self._run_search(fd) + self._flush_results_buffer() except UnicodeDecodeError: log.exception("caught UnicodeDecodeError while searching %s", path) raise @@ -1076,7 +1207,8 @@ def reset(self): 'jobs_completed': 0, 'total_jobs': 0, 'results': 0, - 'num_deduped': 0} + 'parts_deduped': 0, + 'parts_non_deduped': 0} def update(self, stats): # pylint: disable=W0221 if not stats: @@ -1343,23 +1475,22 @@ def _ensure_worker_processes_killed(self): except Exception: log.debug('worker process %s already killed', wpid) - def _run_single(self, results, results_store): + def _run_single(self, results_collection, results_store): """ Run a single search using this process. 

-        @param results: SearchResultsCollection object
+        @param results_collection: SearchResultsCollection object
+        @param results_store: ResultsStoreSimple object
         """
-        results_queue = multiprocessing.Queue()
         for info in self.catalog:
             task = SearchTask(info,
                               constraints_manager=self.constraints_manager,
-                              results_queue=results_queue,
+                              results_collection=results_collection,
                               results_store=results_store,
                               decode_errors=self.decode_errors)
             self.stats.update(task.execute())

         self.stats['jobs_completed'] = 1
         self.stats['total_jobs'] = 1
-        self._purge_results(results, results_queue, self.stats['results'])

     def _run_mp(self, mgr, results, results_store):
         """ Run searches in parallel.
@@ -1438,20 +1569,23 @@ def run(self):
                                              for p in self.catalog])
         self.stats['searches_by_job'] = [len(p['searches'])
                                          for p in self.catalog]
+        log.debug(repr(self.catalog))
         if len(self.files) > 1:
             log.debug("running searches (parallel=True)")
             with multiprocessing.Manager() as mgr:
                 rs = ResultStoreParallel(mgr)
                 results = SearchResultsCollection(self.catalog, rs)
                 self._run_mp(mgr, results, rs)
-                self.stats['num_deduped'] = rs.num_deduped
+                self.stats['parts_deduped'] = rs.parts_deduped
+                self.stats['parts_non_deduped'] = rs.parts_non_deduped
                 rs.unproxy_results()
         else:
             log.debug("running searches (parallel=False)")
             rs = ResultStoreSimple()
             results = SearchResultsCollection(self.catalog, rs)
             self._run_single(results, rs)
-            self.stats['num_deduped'] = rs.num_deduped
+            self.stats['parts_deduped'] = rs.parts_deduped
+            self.stats['parts_non_deduped'] = rs.parts_non_deduped

         log.debug("filesearcher: completed (%s)", self.stats)
         return results
diff --git a/tests/unit/test_search.py b/tests/unit/test_search.py
index a98bc54..39a0cbb 100644
--- a/tests/unit/test_search.py
+++ b/tests/unit/test_search.py
@@ -228,6 +228,8 @@ def test_simple_search_many_files(self):
         for i in range(1000):
             with open(os.path.join(dtmp, str(i)), 'w') as fd:
                 fd.write("a key: foo bar\n")
+                for i in range(1000):
+                    fd.write("some extra text\n")
                 fd.write("a key: bar foo\n")

         f.add(SearchDef(r'.+:\s+(\S+) \S+', tag='simple'), dtmp + '/*')
@@ -339,7 +341,8 @@ def test_large_sequence_search(self):
         finally:
             shutil.rmtree(dtmp)

-        self.assertEqual(f.stats['num_deduped'], 40037)
+        self.assertEqual(f.stats['parts_deduped'], 40037)
+        self.assertEqual(f.stats['parts_non_deduped'], 3)
         self.assertEqual(len(results), 40040)
         self.assertEqual(len(results.find_by_tag('simple')), 20000)
@@ -907,14 +910,17 @@ def test_logs_since_hours_sd(self):
     def test_search_result_index(self):
         sri = ResultStoreSimple()
         for val in ['foo', 'bar', 'foo']:
-            sri.add(val)
+            sri.add('atag', None, val)

-        self.assertEqual(sri, {0: 'foo', 1: 'bar'})
-        self.assertEqual(sri.meta, {0: 2, 1: 1})
+        self.assertEqual(sri, ['foo', 'bar'])
+        self.assertEqual(sri.counters, {0: 2, 1: 1})
+        self.assertEqual(sri.tag_store, ['atag'])
         self.assertEqual(sri[0], 'foo')
-        self.assertEqual(sri.get(1), 'bar')
-        self.assertEqual(sri.get(2), None)
-        self.assertEqual(sri.num_deduped, 1)
+
+        self.assertEqual(sri[1], 'bar')
+        self.assertEqual(sri[2], None)
+        self.assertEqual(sri.parts_deduped, 1)
+        self.assertEqual(sri.parts_non_deduped, 2)

     def test_search_unicode_decode_w_error(self):
         f = FileSearcher()
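
Reviewer note (not part of the patch): a minimal usage sketch of the reworked
de-duplicating store, mirroring the updated test_search_result_index unit test
above. It assumes ResultStoreSimple is imported from searchkit.search as in the
unit tests; the variable names below are illustrative only.

    from searchkit.search import ResultStoreSimple

    store = ResultStoreSimple()

    # add() returns (tag_idx, sequence_id_idx, value_idx); an index of -1
    # means the corresponding input was None and nothing was stored for it.
    for val in ['foo', 'bar', 'foo']:
        tag_idx, seq_idx, val_idx = store.add('atag', None, val)

    print(list(store))              # ['foo', 'bar'] - duplicate 'foo' stored once
    print(store.counters)           # {0: 2, 1: 1}
    print(store.parts_deduped)      # 1
    print(store.parts_non_deduped)  # 2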