Get started with Gemma using KerasNLP

This tutorial shows you how to get started with Gemma using KerasNLP. Gemma is a family of lightweight, state-of-the-art open models built from the same research and technology used to create the Gemini models. KerasNLP is a collection of natural language processing (NLP) models implemented in Keras and runnable on JAX, PyTorch, and TensorFlow.

In this tutorial, you'll use Gemma to generate text responses to several prompts. If you're new to Keras, you might want to read Getting started with Keras before you begin, but you don't have to. You'll learn more about Keras as you work through this tutorial.

Setup

Gemma setup

To complete this tutorial, you first need to complete the Gemma setup instructions. Gemma models are hosted on Kaggle, and to use Gemma you must request access there:

  • Sign in or register at kaggle.com.
  • Open the Gemma model card and select "Request Access".
  • Complete the consent form and accept the terms and conditions.

Install dependencies

Install Keras and KerasNLP.

# Install the latest Keras 3. See https://keras.io/getting_started/ for more details.
!pip install -q -U keras-nlp
!pip install -q -U "keras>=3"

Select a backend

Keras is a high-level, multi-framework deep learning API designed for simplicity and ease of use. Keras 3 lets you choose the backend: TensorFlow, JAX, or PyTorch. All three work for this tutorial. Note that the KERAS_BACKEND environment variable must be set before Keras is imported, so select the backend first.

import os

os.environ["KERAS_BACKEND"] = "jax"  # Or "tensorflow" or "torch".

Import packages

Import Keras and KerasNLP.

import keras
import keras_nlp
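
Optionally, verify that the install and backend selection took effect. A minimal check, assuming a Keras 3 installation (keras.config.backend() is how Keras 3 reports the active backend):

# Print the Keras version and the active backend.
print(keras.__version__)       # should be 3.x
print(keras.config.backend())  # should match the KERAS_BACKEND you set, e.g. "jax"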

Create a model

KerasNLP provides implementations of many popular model architectures. In this tutorial, you'll create a model using GemmaCausalLM, an end-to-end Gemma model for causal language modeling. A causal language model predicts the next token based on the previous tokens.
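
To make that idea concrete, here is a conceptual sketch (not the KerasNLP internals; predict_next is a hypothetical stand-in for one next-token step of the model):

# Causal (autoregressive) generation: each new token is predicted from all
# tokens so far, then appended to the sequence before the next prediction.
def generate_tokens(predict_next, prompt_tokens, max_length):
    tokens = list(prompt_tokens)
    while len(tokens) < max_length:
        tokens.append(predict_next(tokens))
    return tokens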

Create the model using the from_preset method:

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")

The from_preset method instantiates the model from a preset architecture and weights. In the code above, the string "gemma_2b_en" specifies the preset: a Gemma model with 2 billion parameters. (A Gemma model with 7 billion parameters is also available. To run the larger model in Colab, you need access to the premium GPUs offered in paid plans. Alternatively, you can perform distributed tuning of the Gemma 7B model on Kaggle or Google Cloud.)
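
Loading the larger model works the same way; a sketch, assuming the "gemma_7b_en" preset name used by KerasNLP at the time of writing and an accelerator with enough memory for the weights:

# Same API, larger preset: roughly 7 billion parameters.
gemma_7b = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")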

Use the summary method to get more info about the model:

gemma_lm.summary()
Preprocessor: "gemma_causal_lm_preprocessor"

Tokenizer (type)                        Vocab #
gemma_tokenizer (GemmaTokenizer)        256,000

Model: "gemma_causal_lm"

Layer (type)                            Output Shape          Param #          Connected to
padding_mask (InputLayer)               (None, None)          0                -
token_ids (InputLayer)                  (None, None)          0                -
gemma_backbone (GemmaBackbone)          (None, None, 2048)    2,506,172,416    padding_mask[0][0], token_ids[0][0]
token_embedding (ReversibleEmbedding)   (None, None, 256000)  524,288,000      gemma_backbone[0][0]

 Total params: 2,506,172,416 (9.34 GB)
 Trainable params: 2,506,172,416 (9.34 GB)
 Non-trainable params: 0 (0.00 B)

As you can see from the summary, the model has 2.5 billion trainable parameters.

Generate text

Now it's time to generate some text! The model has a generate method that generates text based on a prompt. The optional max_length argument specifies the maximum length of the generated sequence.

Try it out with the prompt "What is the meaning of life?".

gemma_lm.generate("What is the meaning of life?", max_length=64)
'What is the meaning of life?\n\nThe question is one of the most important questions in the world.\n\nIt’s the question that has been asked by philosophers, theologians, and scientists for centuries.\n\nAnd it’s the question that has been asked by people who are looking for answers to their own lives'

Try calling generate again with a different prompt.

gemma_lm.generate("How does the brain work?", max_length=64)
'How does the brain work?\n\nThe brain is the most complex organ in the human body. It is responsible for controlling all of the body’s functions, including breathing, heart rate, digestion, and more. The brain is also responsible for thinking, feeling, and making decisions.\n\nThe brain is made up'

If you're using a JAX or TensorFlow backend, you'll notice that the second generate call returns almost immediately. This is because each call to generate for a given batch size and max_length is compiled with XLA. The first run is expensive, but subsequent runs are much faster.
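
You can see the compilation cost directly by timing the calls; a small sketch, using the gemma_lm model created above:

import time

# First call for a given (batch size, max_length): includes XLA compilation.
start = time.time()
gemma_lm.generate("What is the meaning of life?", max_length=64)
print(f"first call:  {time.time() - start:.2f} s")

# Same shape again: reuses the compiled program and returns much faster.
start = time.time()
gemma_lm.generate("How does the brain work?", max_length=64)
print(f"second call: {time.time() - start:.2f} s")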

You can also provide batched prompts by passing a list as input:

gemma_lm.generate(
    ["What is the meaning of life?",
     "How does the brain work?"],
    max_length=64)
['What is the meaning of life?\n\nThe question is one of the most important questions in the world.\n\nIt’s the question that has been asked by philosophers, theologians, and scientists for centuries.\n\nAnd it’s the question that has been asked by people who are looking for answers to their own lives',
 'How does the brain work?\n\nThe brain is the most complex organ in the human body. It is responsible for controlling all of the body’s functions, including breathing, heart rate, digestion, and more. The brain is also responsible for thinking, feeling, and making decisions.\n\nThe brain is made up']

Optional: Try a different sampler

You can control the generation strategy for GemmaCausalLM by setting the sampler argument on compile(). By default, "greedy" sampling is used.

As an experiment, try a "top_k" strategy:

gemma_lm.compile(sampler="top_k")
gemma_lm.generate("What is the meaning of life?", max_length=64)
'What is the meaning of life? That is a question that has been asked for centuries and has yet to be answered. However, there are some people who believe they know the answer and they are willing to share it with the rest of us. In this essay, I will explore the meaning of life from their perspective'

While the default greedy algorithm always picks the token with the highest probability, the top-K algorithm randomly picks the next token from among the K tokens with the highest probability.
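
The "top_k" string uses the sampler's default settings. To choose K yourself, you can pass a sampler object instead of a string; a minimal sketch, assuming the keras_nlp.samplers.TopKSampler class available in current KerasNLP releases:

# Sample from the 5 most likely tokens at each step; seed makes runs repeatable.
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
gemma_lm.generate("What is the meaning of life?", max_length=64)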

You don't have to specify a sampler, and you can ignore the last code snippet if it doesn't help your use case. If you'd like to learn more about the available samplers, see the Samplers section.

Source

Translated from Get started with Gemma using KerasNLP: https://www.kaggle.com/code/nilaychauhan/get-started-with-gemma-using-kerasnlp/notebook