Source code for pycognaize.common.langchain_loader

import importlib
from pycognaize.common.decorators import module_not_found
from pycognaize.common.enums import PageLayoutEnum
from pycognaize.document.document import Document
from pycognaize.document.field import Field, TableField


[docs] class LangchainLoader: """ Convert Pycognaize Document Object to Langchain Document Object """ INPUT_FIELDS: list[str] = list(i.value for i in PageLayoutEnum) OVERLAP = 512 LIMIT = 2048 @module_not_found() def __init__(self, document: Document) -> None: """The constructor takes a cognaize document object as an input. The cognaize document object should have page layout and table data in document.x, the pythonnames for each field is defined as a class attribute. """ transformers = importlib.import_module('transformers') self.document = document self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2") @module_not_found() def load_and_split(self): """ load and split the Document into separate langchain document objects """ langchain = importlib.import_module('langchain') text_blocks_w_metadata: list[list[tuple[dict, str]]] =\ self._get_as_text(self.document) metadata_list, text_list = self._create_text_and_metadata( text_blocks_w_metadata, document_id=self.get_document_src()) text_splitter = langchain.text_splitter.RecursiveCharacterTextSplitter( chunk_size=self.LIMIT, chunk_overlap=self.OVERLAP, length_function=self.count_tokens) docs = text_splitter.create_documents(texts=text_list, metadatas=metadata_list) return docs
[docs] def get_document_src(self) -> str: """Get the SHA of the document""" return self.document.metadata['src']
[docs] def count_tokens(self, text: str) -> int: """Tokenize the text and count the number of tokens""" return len(self.tokenizer.encode(text))
@staticmethod def _get_table_group(current_group: list[tuple[dict, str]], metadata, text_block): if current_group and current_group[-1][0]['block'] in \ (PageLayoutEnum.PAGE_HEADER, PageLayoutEnum.SECTION_HEADER): table_group = [current_group[-1], (metadata, text_block)] current_group = current_group[:-1] else: table_group = [(metadata, text_block)] return table_group, current_group def _get_as_text(self, document: Document) -> list[list[tuple[dict, str]]]: """Given a cognaize document object, return a list of strings, where each string represents a single chunk/block. The tables and text are always in separate groups. """ fields = [] for pname in self.INPUT_FIELDS: if pname not in document.x: continue block_fields = [(pname, i) for i in document.x[pname]] fields.extend(block_fields) ordered_blocks_w_metadata: list[ tuple[dict, str]] = self._order_text_blocks(fields) groups = [] current_group = None for block_n, (metadata, text_block) in \ enumerate(ordered_blocks_w_metadata): # First block, create a new group if block_n == 0: current_group = [(metadata, text_block)] # Table, create a separate group if metadata['block'] is PageLayoutEnum.TABLE: table_group, current_group =\ self._get_table_group(current_group, metadata, text_block) # finalize the previous group groups.append(current_group) # finalized the table group groups.append(table_group) # create an empty group current_group = [] elif metadata['block'] in (PageLayoutEnum.PAGE_HEADER, PageLayoutEnum.SECTION_HEADER): # If all elements are headers in the group, continue with the # same group if all(( g[0]['block'] in (PageLayoutEnum.PAGE_HEADER, PageLayoutEnum.SECTION_HEADER) for g in current_group )): current_group.append((metadata, text_block)) elif current_group: groups.append(current_group) current_group = [(metadata, text_block)] else: current_group = [(metadata, text_block)] # Create a new group elif block_n == len(ordered_blocks_w_metadata) - 1 and\ current_group: groups.append(current_group) else: current_group.append((metadata, text_block)) return groups @staticmethod def _create_text_and_metadata( text_blocks_w_metadata: list[list[tuple[dict, str]]], document_id: str ) -> tuple[list[dict], list[str]]: """ Creates metadata and texts for creating LangChain Document objects """ metadata_list = [] text_list = [] for group in text_blocks_w_metadata: group_text = '' group_metadata = {"pages": set(), "source": []} for metadata, text in group: tag = metadata['tag'] group_metadata['pages'].add(tag.page.page_number) group_metadata['source'].append((tag.left, tag.top, tag.right, tag.bottom, tag.page.page_number)) group_text += f'\n{text}' group_metadata['pages'] = list(group_metadata['pages']) group_metadata['document'] = document_id metadata_list.append(group_metadata) text_list.append(group_text) return metadata_list, text_list @staticmethod def _order_text_blocks(fields: list[tuple[str, Field]]) -> \ list[tuple[dict, str]]: """ Order the fields as they appear in the original document. Order by page, top coordinate and left coordinate in that order. """ text_blocks = [] filtered_fields = [] for block_type, field in fields: if isinstance(field, TableField) and field.tags: filtered_fields.append( ( { 'block': PageLayoutEnum(block_type), 'tag': field.tags[0], }, field ) ) else: if field.tags and field.tags[0].raw_value.strip(): filtered_fields.append( ( { 'block': PageLayoutEnum(block_type), 'tag': field.tags[0], }, field ) ) fields: list[tuple[dict, Field]] = sorted( filtered_fields, key=lambda x: (x[1].tags[0].page.page_number, x[1].tags[0].top, x[1].tags[0].left)) for metadata, field in fields: if isinstance(field, TableField): value: str = field.tags[0].to_string() else: value = field.value text_blocks.append((metadata, value)) return text_blocks