RAG with Gemma using the Wikipedia API

  • Environment: Windows 11, CPU: 12th Gen Intel(R) Core(TM) i5-12500H, 16 GB RAM
  • Gemma: a lightweight LLM
  • Built on LangChain; runs are visualized with LangSmith (smith.langchain)

Setup

  1. Create a conda environment and install the dependencies

    • pytorch 2.1.2
    • scikit-learn 1.4.0
    • llama_cpp_python 0.2.52
    • langchain 0.1.8
    • langchain-community 0.0.21
    • langchain-experimental 0.0.52
    • tiktoken 0.6.0
    • sentence-transformers 2.3.1
    • numpy 1.26.4
    • pandas 2.1.4
    • wikipedia
  2. Download the Gemma model

    Model download: gemma-2b-it-q4_k_m.gguf · lmstudio-ai/gemma-2b-it-GGUF at main (huggingface.co)
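A truncated download or a saved HTML error page only shows up later as a confusing load error. A quick sanity check, sketched here as an assumption-based helper (read_gguf_version is a hypothetical name, not part of llama_cpp_python), reads the file header: GGUF files start with the 4 magic bytes "GGUF" followed by a uint32 format version, which is the "GGUF V3" reported in the load log further down. The sketch assumes a little-endian file, the common case.

```python
import struct
from pathlib import Path

GGUF_MAGIC = b"GGUF"  # first 4 bytes of every GGUF file


def read_gguf_version(path):
    """Return the GGUF format version stored in the file header.

    Raises ValueError if the file does not start with the GGUF magic bytes
    (e.g. an incomplete download or an HTML page saved by mistake).
    """
    with open(path, "rb") as f:
        header = f.read(8)
    if len(header) < 8 or header[:4] != GGUF_MAGIC:
        raise ValueError(f"{path} is not a GGUF file")
    # The version follows the magic as a uint32 (little-endian assumed).
    (version,) = struct.unpack("<I", header[4:8])
    return version


if __name__ == "__main__":
    model_path = Path("model/gemma-2b-it-q4_k_m.gguf")  # path used in this post
    if model_path.exists():
        print(model_path, "GGUF version:", read_gguf_version(model_path))
```

For the model used here, the check should report version 3, matching the llama.cpp log.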

Implementation

Overview of the RAG process

For details, see the previous blog post and the LangChain demo overview.
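The three RAG stages (retrieve, augment, generate) can be illustrated without any external service. The sketch below is framework-free and is not how LangChain implements them: the tiny corpus, the keyword-overlap scorer and the generate stub are hypothetical stand-ins for Wikipedia, WikipediaRetriever and Gemma.

```python
import re
from collections import Counter

# Hypothetical toy corpus standing in for Wikipedia (illustration only).
CORPUS = {
    "Cheetah": "The cheetah is a large cat and the fastest land animal.",
    "Lion": "The lion is a large cat native to Africa and India.",
    "Python": "Python is a high-level programming language.",
}


def tokenize(text):
    return re.findall(r"[a-z0-9]+", text.lower())


def retrieve(question, k=2):
    """Retrieve: rank documents by word overlap with the question (toy scorer)."""
    q = Counter(tokenize(question))
    scored = []
    for title, text in CORPUS.items():
        overlap = sum((q & Counter(tokenize(text))).values())
        scored.append((overlap, title, text))
    scored.sort(key=lambda item: item[0], reverse=True)
    # Drop documents that share no words with the question.
    return [(title, text) for score, title, text in scored[:k] if score > 0]


def augment(question, docs):
    """Augment: paste the retrieved snippets into the prompt as context."""
    context = "\n\n".join(f"Article Title: {t}\nArticle Snippet: {s}" for t, s in docs)
    return f"Answer using only this context:\n{context}\n\nQuestion: {question}\nAnswer:"


def generate(prompt):
    """Generate: placeholder for the LLM call (Gemma via LlamaCpp in this post)."""
    return f"<LLM output for a {len(prompt)}-character prompt>"


question = "How fast is the cheetah?"
print(generate(augment(question, retrieve(question))))
```

In the LangChain version below, the retriever, the prompt template and the LLM simply replace these three functions.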

Code

Using Gemma, with Wikipedia as the knowledge source

```python
from langchain_community.retrievers import WikipediaRetriever
from langchain_community.llms import LlamaCpp
from langchain import hub  # hub.pull() also requires the `langchainhub` package
from langchain.prompts import PromptTemplate
from typing import List
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import (
    RunnablePassthrough,
)
from langchain.prompts.chat import ChatPromptTemplate
import os

# Uncomment to trace the run in LangSmith (smith.langchain)
# os.environ['LANGCHAIN_TRACING_V2'] = 'true'
# os.environ['LANGCHAIN_ENDPOINT'] = 'https://api.smith.langchain.com'
# os.environ['LANGCHAIN_API_KEY'] = ''  # get this from https://api.smith.langchain.com

# LLM
model_path = "model/gemma-2b-it-q4_k_m.gguf"  # path to the quantized Gemma model
llm = LlamaCpp(
    model_path=model_path,
    streaming=False,
    n_gpu_layers=30,  # ignored by a CPU-only build (the log below shows only CPU buffers)
    # n_ctx is not set: the log below reports n_ctx = 512; raise it if the retrieved context grows
)

# Retrieve from Wikipedia using WikipediaRetriever
# top_k_results limits how many articles are fetched, doc_content_chars_max their maximum length
wiki = WikipediaRetriever(top_k_results=2, doc_content_chars_max=550)


# Post-processing
def format_docs(docs: List[Document]) -> str:
    """Convert Documents to a single string."""
    formatted = [
        f"Article Title: {doc.metadata['title']}\nArticle Snippet: {doc.page_content}"
        for doc in docs
    ]
    return "\n\n" + "\n\n".join(formatted)


# Prompt
# from the LangChain Hub
rag_prompt = hub.pull("rlm/rag-prompt")
# self-defined alternative
# rag_prompt = ChatPromptTemplate.from_messages(
#     [
#         (
#             "system",
#             "You're a helpful AI assistant. Given a user question and some Wikipedia article snippets, answer the user question. If none of the articles answer the question, just say you don't know.\n\nHere are the Wikipedia articles:{context}",
#         ),
#         ("human", "{question}"),
#     ]
# )

# Chain
rag_chain = (
    {"context": wiki | format_docs, "question": RunnablePassthrough()}
    | rag_prompt
    | llm
    | StrOutputParser()
)
# Question
print(rag_chain.invoke("How fast are cheetahs?"))
```

Note: the code queries Wikipedia at run time, and requests from some IP addresses may be refused.
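In LCEL, the dict at the head of rag_chain is turned into a RunnableParallel: the same input (the question string) is sent to every branch, and the results are collected into one dict that fills the prompt's {context} and {question} slots. The offline sketch below mimics that step; Doc and fake_wiki are hypothetical stand-ins for Document and WikipediaRetriever, so it runs without LangChain or network access.

```python
from dataclasses import dataclass, field


@dataclass
class Doc:
    """Hypothetical stand-in for langchain_core.documents.Document."""
    page_content: str
    metadata: dict = field(default_factory=dict)


def format_docs(docs):
    """Same formatting as format_docs in the script above."""
    formatted = [
        f"Article Title: {doc.metadata['title']}\nArticle Snippet: {doc.page_content}"
        for doc in docs
    ]
    return "\n\n" + "\n\n".join(formatted)


def fake_wiki(question):
    """Hypothetical retriever stub returning a fixed document (no network)."""
    return [Doc("The cheetah is the fastest land animal.", {"title": "Cheetah"})]


def run_input_map(question):
    # Mimics {"context": wiki | format_docs, "question": RunnablePassthrough()}:
    # one input goes to both branches, and the outputs are merged into a dict.
    return {"context": format_docs(fake_wiki(question)), "question": question}


print(run_input_map("How fast are cheetahs?"))
```

The resulting dict is exactly what rag_prompt receives before the text is passed to the LLM.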

  • Run output

    answer: The cheetah is capable of running at 93 to 104 km/h (58 to 65 mph).

    llama_model_loader: loaded meta data with 21 key-value pairs and 164 tensors from model/gemma-2b-it-q4_k_m.gguf (version GGUF V3 (latest))
    llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
    llama_model_loader: - kv 0: general.architecture str = gemma
    llama_model_loader: - kv 1: general.name str = gemma-2b-it
    llama_model_loader: - kv 2: gemma.context_length u32 = 8192
    llama_model_loader: - kv 3: gemma.block_count u32 = 18
    llama_model_loader: - kv 4: gemma.embedding_length u32 = 2048
    llama_model_loader: - kv 5: gemma.feed_forward_length u32 = 16384
    llama_model_loader: - kv 6: gemma.attention.head_count u32 = 8
    llama_model_loader: - kv 7: gemma.attention.head_count_kv u32 = 1
    llama_model_loader: - kv 8: gemma.attention.key_length u32 = 256
    llama_model_loader: - kv 9: gemma.attention.value_length u32 = 256
    llama_model_loader: - kv 10: gemma.attention.layer_norm_rms_epsilon f32 = 0.000001
    llama_model_loader: - kv 11: tokenizer.ggml.model str = llama
    llama_model_loader: - kv 12: tokenizer.ggml.bos_token_id u32 = 2
    llama_model_loader: - kv 13: tokenizer.ggml.eos_token_id u32 = 1
    llama_model_loader: - kv 14: tokenizer.ggml.padding_token_id u32 = 0
    llama_model_loader: - kv 15: tokenizer.ggml.unknown_token_id u32 = 3
    llama_model_loader: - kv 16: tokenizer.ggml.tokens arr[str,256128] = ["<pad>", "<eos>", "<bos>", "<unk>", ...
    llama_model_loader: - kv 17: tokenizer.ggml.scores arr[f32,256128] = [0.000000, 0.000000, 0.000000, 0.0000...
    llama_model_loader: - kv 18: tokenizer.ggml.token_type arr[i32,256128] = [3, 3, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, ...
    llama_model_loader: - kv 19: general.quantization_version u32 = 2
    llama_model_loader: - kv 20: general.file_type u32 = 15
    llama_model_loader: - type f32: 37 tensors
    llama_model_loader: - type q4_K: 109 tensors
    llama_model_loader: - type q6_K: 18 tensors
    llm_load_vocab: mismatch in special tokens definition ( 544/256128 vs 388/256128 ).
    llm_load_print_meta: format = GGUF V3 (latest)
    llm_load_print_meta: arch = gemma
    llm_load_print_meta: vocab type = SPM
    llm_load_print_meta: n_vocab = 256128
    llm_load_print_meta: n_merges = 0
    llm_load_print_meta: n_ctx_train = 8192
    llm_load_print_meta: n_embd = 2048
    llm_load_print_meta: n_head = 8
    llm_load_print_meta: n_head_kv = 1
    llm_load_print_meta: n_layer = 18
    llm_load_print_meta: n_rot = 256
    llm_load_print_meta: n_embd_head_k = 256
    llm_load_print_meta: n_embd_head_v = 256
    llm_load_print_meta: n_gqa = 8
    llm_load_print_meta: n_embd_k_gqa = 256
    llm_load_print_meta: n_embd_v_gqa = 256
    llm_load_print_meta: f_norm_eps = 0.0e+00
    llm_load_print_meta: f_norm_rms_eps = 1.0e-06
    llm_load_print_meta: f_clamp_kqv = 0.0e+00
    llm_load_print_meta: f_max_alibi_bias = 0.0e+00
    llm_load_print_meta: n_ff = 16384
    llm_load_print_meta: n_expert = 0
    llm_load_print_meta: n_expert_used = 0
    llm_load_print_meta: pooling type = 0
    llm_load_print_meta: rope type = 2
    llm_load_print_meta: rope scaling = linear
    llm_load_print_meta: freq_base_train = 10000.0
    llm_load_print_meta: freq_scale_train = 1
    llm_load_print_meta: n_yarn_orig_ctx = 8192
    llm_load_print_meta: rope_finetuned = unknown
    llm_load_print_meta: model type = 2B
    llm_load_print_meta: model ftype = Q4_K - Medium
    llm_load_print_meta: model params = 2.51 B
    llm_load_print_meta: model size = 1.39 GiB (4.75 BPW)
    llm_load_print_meta: general.name = gemma-2b-it
    llm_load_print_meta: BOS token = 2 '<bos>'
    llm_load_print_meta: EOS token = 1 '<eos>'
    llm_load_print_meta: UNK token = 3 '<unk>'
    llm_load_print_meta: PAD token = 0 '<pad>'
    llm_load_print_meta: LF token = 227 '<0x0A>'
    llm_load_tensors: ggml ctx size = 0.06 MiB
    llm_load_tensors: CPU buffer size = 1420.21 MiB
    ............................................................
    llama_new_context_with_model: n_ctx = 512
    llama_new_context_with_model: freq_base = 10000.0
    llama_new_context_with_model: freq_scale = 1
    llama_kv_cache_init: CPU KV buffer size = 9.00 MiB
    llama_new_context_with_model: KV self size = 9.00 MiB, K (f16): 4.50 MiB, V (f16): 4.50 MiB
    llama_new_context_with_model: CPU input buffer size = 0.08 MiB
    llama_new_context_with_model: CPU compute buffer size = 7.88 MiB
    llama_new_context_with_model: graph splits (measure): 1
    AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 1 | SSSE3 = 0 | VSX = 0 | MATMUL_INT8 = 0 |
    Model metadata: {'general.name': 'gemma-2b-it', 'general.architecture': 'gemma', 'gemma.context_length': '8192', 'gemma.block_count': '18', 'gemma.attention.head_count_kv': '1', 'gemma.embedding_length': '2048', 'gemma.feed_forward_length': '16384', 'gemma.attention.head_count': '8', 'gemma.attention.key_length': '256', 'gemma.attention.value_length': '256', 'gemma.attention.layer_norm_rms_epsilon': '0.000001', 'tokenizer.ggml.model': 'llama', 'general.quantization_version': '2', 'tokenizer.ggml.bos_token_id': '2', 'general.file_type': '15', 'tokenizer.ggml.eos_token_id': '1', 'tokenizer.ggml.padding_token_id': '0', 'tokenizer.ggml.unknown_token_id': '3'}
    The cheetah is capable of running at 93 to 104 km/h (58 to 65 mph).

    llama_print_timings: load time = 478.43 ms
    llama_print_timings: sample time = 41.70 ms / 28 runs ( 1.49 ms per token, 671.48 tokens per second)
    llama_print_timings: prompt eval time = 13944.05 ms / 378 tokens ( 36.89 ms per token, 27.11 tokens per second)
    llama_print_timings: eval time = 1394.02 ms / 27 runs ( 51.63 ms per token, 19.37 tokens per second)
    llama_print_timings: total time = 16688.10 ms / 405 tokens

    Process finished with exit code 0
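The llama_print_timings lines show where the time goes. The parser below is a sketch; its regex is an assumption based only on the lines shown above, not a documented llama.cpp format. It recomputes tokens per second as count / seconds, which reproduces the 27.11 and 19.37 tokens/s figures in the log. In this run, prompt evaluation takes about 13.9 s of the 16.7 s total, which suggests that on this CPU the size of the retrieved context, rather than the length of the answer, dominates latency.

```python
import re

# Matches lines such as:
# "llama_print_timings: prompt eval time = 13944.05 ms / 378 tokens ( ... )"
# Lines without a "/ count" part (e.g. "load time") are skipped.
TIMING_RE = re.compile(
    r"llama_print_timings:\s*(?P<name>[\w ]+?)\s*=\s*(?P<ms>[\d.]+) ms"
    r"\s*/\s*(?P<count>\d+) (?:tokens|runs)"
)


def parse_timings(log_text):
    """Return {phase: (milliseconds, count, tokens_per_second)} from a llama.cpp log."""
    results = {}
    for match in TIMING_RE.finditer(log_text):
        ms = float(match["ms"])
        count = int(match["count"])
        results[match["name"]] = (ms, count, count / (ms / 1000.0))
    return results


log = """llama_print_timings: load time = 478.43 ms
llama_print_timings: prompt eval time = 13944.05 ms / 378 tokens ( 36.89 ms per token, 27.11 tokens per second)
llama_print_timings: eval time = 1394.02 ms / 27 runs ( 51.63 ms per token, 19.37 tokens per second)"""

for phase, (ms, count, tps) in parse_timings(log).items():
    print(f"{phase}: {count} in {ms / 1000:.1f} s -> {tps:.2f} tokens/s")
```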

Inspecting the run in LangSmith (smith.langchain)

  • The retriever fetches the documents; two are returned here, matching

    wiki = WikipediaRetriever(top_k_results=2, doc_content_chars_max=550)

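  • format_docs: the retrieved documents merged into a single context string

  • The question and the answer, with the retrieved content integrated into the prompt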

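Going further

Enhancing the RAG process, e.g. improving the effectiveness of retrieval

  • Example 1

    Source: Citations | 🦜️🔗 LangChain

    Post-process the retrieved documents to compress their content, so that the source content is already minimal enough that the model does not need to cite specific sources or spans.

    Modified code: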

    ```python
    # Continues the script above: reuses llm, model_path, WikipediaRetriever,
    # ChatPromptTemplate, List, Document, StrOutputParser and RunnablePassthrough.
    def Retrieval_post_processing():
        '''
        Post-process the retrieved documents to compress their content, so that the
        source content is already minimal enough that we don't need the model to
        cite specific sources or spans.

        Prints a dict with the model's answer and the filtered documents.
        '''
        from operator import itemgetter

        from langchain.retrievers.document_compressors import EmbeddingsFilter
        from langchain_community.embeddings import LlamaCppEmbeddings
        from langchain.text_splitter import RecursiveCharacterTextSplitter
        from langchain_core.runnables import RunnableLambda, RunnableParallel

        wiki = WikipediaRetriever(top_k_results=3, doc_content_chars_max=1000)
        rag_prompt = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    "You're a helpful AI assistant. Given a user question and some Wikipedia article snippets, answer the user question. If none of the articles answer the question, just say you don't know.\n\nHere are the Wikipedia articles:{context}",
                ),
                ("human", "{question}"),
            ]
        )

        def format_docs(docs: List[Document]) -> str:
            """Convert Documents to a single string."""
            formatted = [
                f"Article Title: {doc.metadata['title']}\nArticle Snippet: {doc.page_content}"
                for doc in docs
            ]
            return "\n\n" + "\n\n".join(formatted)

        format_context = itemgetter("docs") | RunnableLambda(format_docs)
        # subchain for generating an answer once we've done retrieval
        answer = rag_prompt | llm | StrOutputParser()

        splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=0,
            separators=["\n\n", "\n", ".", " "],
            keep_separator=False,
        )
        # keep the k=3 chunks whose embeddings are most similar to the question
        # (LlamaCppEmbeddings loads the GGUF model a second time, in embedding mode)
        compressor = EmbeddingsFilter(embeddings=LlamaCppEmbeddings(model_path=model_path), k=3)

        def split_and_filter(inputs) -> List[Document]:
            docs = inputs["docs"]
            question = inputs["question"]
            split_docs = splitter.split_documents(docs)
            stateful_docs = compressor.compress_documents(split_docs, question)
            return list(stateful_docs)

        retrieve = (
            RunnableParallel(question=RunnablePassthrough(), docs=wiki) | split_and_filter
        )
        # docs = retrieve.invoke("How fast are cheetahs?")
        # for doc in docs:
        #     print(doc.page_content)
        #     print("\n\n")
        chain_4 = (
            RunnableParallel(question=RunnablePassthrough(), docs=retrieve)
            .assign(context=format_context)
            .assign(answer=answer)
            .pick(["answer", "docs"])
        )
        # Note the documents have an article "summary" in the metadata that is now much longer than the
        # actual document page content. This summary isn't actually passed to the model.
        print(chain_4.invoke("How fast are cheetahs?"))


    Retrieval_post_processing()
    ```
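The idea behind chain_4 is to split the retrieved articles into chunks of at most 500 characters, then keep only the k=3 chunks whose embeddings are most similar to the question (EmbeddingsFilter ranks by cosine similarity by default). It can be sketched without LangChain or a model, under two loudly stated assumptions: split_text is a stripped-down recursive splitter, not RecursiveCharacterTextSplitter's exact algorithm (it re-joins pieces with a space), and a bag-of-words Counter stands in for LlamaCppEmbeddings. All function names here are hypothetical.

```python
import math
import re
from collections import Counter


def split_text(text, chunk_size=500, separators=("\n\n", "\n", ".", " ")):
    """Simplified recursive splitter: split on the first separator present,
    merge pieces up to chunk_size, and recurse with finer separators on pieces
    that are still too long. Separators are dropped (like keep_separator=False)."""
    text = text.strip()
    if not text:
        return []
    if len(text) <= chunk_size:
        return [text]
    for i, sep in enumerate(separators):
        if sep in text:
            finer = separators[i + 1:]
            break
    else:
        # No separator left: hard-cut into fixed-size windows.
        return [text[j:j + chunk_size] for j in range(0, len(text), chunk_size)]
    chunks, current = [], ""
    for piece in text.split(sep):
        piece = piece.strip()
        if not piece:
            continue
        if len(piece) > chunk_size:
            if current:
                chunks.append(current)
                current = ""
            chunks.extend(split_text(piece, chunk_size, finer))
        elif not current:
            current = piece
        elif len(current) + 1 + len(piece) <= chunk_size:
            current = f"{current} {piece}"
        else:
            chunks.append(current)
            current = piece
    if current:
        chunks.append(current)
    return chunks


def embed(text):
    """Toy bag-of-words 'embedding' standing in for LlamaCppEmbeddings."""
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))


def cosine(a, b):
    dot = sum(a[t] * b[t] for t in a if t in b)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


def split_and_filter(docs, question, k=3, chunk_size=500):
    """Split every document, then keep the k chunks most similar to the question."""
    chunks = [c for d in docs for c in split_text(d, chunk_size)]
    q = embed(question)
    ranked = sorted(chunks, key=lambda c: cosine(embed(c), q), reverse=True)
    return ranked[:k]
```

With a small chunk_size, for example, only the sentence about running speed survives for "How fast can cheetahs run?", which is why the compressed context can fit a small context window.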

Applying RAG to other tasks

e.g. LLM + classification tasks

Empowering search engines

https://morphic.sh/

https://github.com/miurla/morphic

References

rag-from-scratch/rag_from_scratch_1_to_4.ipynb at main · langchain-ai/rag-from-scratch

Citations | 🦜️🔗 LangChain

rllm/examples/small_examples_with_LLM at main · rllm-project/rllm (github.com)