スポンサーリンク

【LangChain】CSVLoaderでCSVファイルの項目をmetadataとして設定できるようにする

記事内に広告が含まれています。

ChatGPTに外部データをもとにした回答生成させるために、ベクトルデータベースを作成していました。CSVファイルのある列をベクトル化し、ある列をメタデータ(metadata)に設定したかったのですが、CSVLoaderクラスのload関数ではできなかったため、変更を加えました。

以下のコードはcsv_loader.pyのCSVLoaderクラスの内容です。

import csv
from typing import Dict, List, Optional

from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader


class CSVLoader(BaseLoader):
    """Loads a CSV file into a list of documents.

    Each document represents one row of the CSV file. Every row is converted into a
    key/value pair and outputted to a new line in the document's page_content.

    The source for each document loaded from csv is set to the value of the
    `file_path` argument for all doucments by default.
    You can override this by setting the `source_column` argument to the
    name of a column in the CSV file.
    The source of each document will then be set to the value of the column
    with the name specified in `source_column`.

    Output Example:
        .. code-block:: txt

            column1: value1
            column2: value2
            column3: value3
    """

    def __init__(
        self,
        file_path: str,
        source_column: Optional[str] = None,
        csv_args: Optional[Dict] = None,
        encoding: Optional[str] = None,
    ):
        self.file_path = file_path
        self.source_column = source_column
        self.encoding = encoding
        self.csv_args = csv_args or {}

    def load(self) -> List[Document]:
        """Load data into document objects."""

        docs = []
        with open(self.file_path, newline="", encoding=self.encoding) as csvfile:
            csv_reader = csv.DictReader(csvfile, **self.csv_args)  # type: ignore
            for i, row in enumerate(csv_reader):
                content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items())
                try:
                    source = (
                        row[self.source_column]
                        if self.source_column is not None
                        else self.file_path
                    )
                except KeyError:
                    raise ValueError(
                        f"Source column '{self.source_column}' not found in CSV file."
                    )
                metadata = {"source": source, "row": i}
                doc = Document(page_content=content, metadata=metadata)
                docs.append(doc)

        return docs

以下のpage_contentとmetadataに、CSVファイル中の任意の列の値を設定したいため、コードを変更します。

doc = Document(page_content=content, metadata=metadata)

変更後のコードは以下のとおりです。

CSVLoaderのインスタンス生成時、page_contentに設定したいCSVファイル中のカラム名をtext_columnに、metadataに設定したいCSVファイル中のカラム名をmeta_columnsに設定します。

import csv
from typing import Dict, List, Optional
from langchain.docstore.document import Document


class CSVLoader:

    def __init__(
        self,
        file_path: str,
        text_column: str,
        source_column: Optional[str] = None,
        meta_columns: Optional[list] = None,
        csv_args: Optional[Dict] = None,
        encoding: Optional[str] = None,
    ):
        self.file_path = file_path
        self.source_column = source_column
        self.text_column = text_column
        self.meta_columns = meta_columns
        self.encoding = encoding
        self.csv_args = csv_args or {}

    def load(self) -> List[Document]:

        docs = []
        with open(self.file_path, newline="", encoding=self.encoding) as csvfile:
            csv_reader = csv.DictReader(csvfile, **self.csv_args)  # type: ignore
            for i, row in enumerate(csv_reader):
                content = row[self.text_column]
                metadata = {}

                try:
                    source = (
                        row[self.source_column]
                        if self.source_column is not None
                        else self.file_path
                    )

                    if self.meta_columns is not None:
                        metadata = {col: row[col] for col in self.meta_columns}

                except KeyError:
                    raise ValueError(
                        f"Some columns are not found in CSV file."
                    )

                metadata["source"] = source
                metadata["row"] = i
                doc = Document(page_content=content, metadata=metadata)
                docs.append(doc)

        return docs

以下の例では、上記コードを記載したloader.pyからCSVLoaderを呼び出しています。test.csvのcomment列をベクトル化対象列とし、year列をmetadata設定対象列にできます。

from loader import CSVLoader

loader = CSVLoader(file_path='test.csv',
                   text_column='comment',
                   meta_columns=['year'],
                   encoding='utf-8',
                   csv_args={"delimiter": ','})
data = loader.load()

以下の環境で簡単に動作確認をしましたが、実行結果が正しいことの保証は出来かねますので検証の上お使いください。

python:3.9.13
langchain:0.0.195

スポンサーリンク
スポンサーリンク
ChatGPTGenerative AILangChain
著者SNS
タイトルとURLをコピーしました