Keras Hub,您的一站式预训练模型库
作者 / 软件工程师 Divyashree Sreepathihalli 和 Google AI 开发技术推广工程师 Luciano Martins
深度学习领域正在迅速发展,在处理各种类型的任务中,预训练模型变得越来越重要。Keras 以其用户友好型 API 和对易用性的重视而闻名,始终处于这一动向的前沿。Keras 拥有专用的内容库,如用于文本模型的 KerasNLP 和用于计算机视觉模型的 KerasCV。
然而,随着模型使各模态之间的界限越来越模糊 (想象一下强大的聊天 LLM 具有图像输入功能或是在视觉任务中利用文本编码器),维持这些独立的领域变得不那么实际。NLP 和 CV 之间的区别可能会阻碍真正多模态模型的发展和部署,从而导致冗余的工作和碎片化的用户体验。
为了解决这个问题,我们很高兴地宣布 Keras 生态系统迎来重大变革:隆重推出 KerasHub,一个统一、全面的预训练模型库,简化了对前沿 NLP 和 CV 架构的访问。KerasHub 是一个中央存储库,您可以在稳定且熟悉的 Keras 框架内无缝探索和使用最先进的模型,例如用于文本分析的 BERT 以及用于图像分类的 EfficientNet。
🔗 KerasHub
https://keras.io/keras_hub/
统一的开发者体验
这种统一不仅简化了对模型的探索和使用,还有助于打造更具凝聚力的生态系统。通过 KerasHub,您可以利用高级功能,例如轻松的发布和共享模型、用于优化资源效率的 LoRA 微调、用于优化性能的量化,以及用于处理大规模数据集的强大多主机训练,所有这些功能都适用于各种模态。这标志着在普及强大的 AI 工具以及加速开发创新型多模态应用方面迈出了重要一步。
KerasHub 入门步骤
首先在您的系统上安装 KerasHub,您可以在其中探索大量现成的模型和主流架构的不同实现方式。然后,您就可以轻松地将这些预训练的模型加载并整合到自己的项目中,并根据您的具体需求对其进行微调,以获得最佳性能。
🔗 现成的模型
https://keras.io/api/keras_hub/models/
安装 KerasHub
要安装带有 Keras 3 的 KerasHub 最新版本,只需运行以下代码:
$ pip install --upgrade keras-hub
现在,您可以开始探索可用的模型。使用 Keras 3 开始工作的标准环境设置在开始使用 KerasHub 时并不需要任何改变:
import os
# Define the Keras 3 backend you want to use - "jax", "tensorflow" or "torch"
os.environ["KERAS_BACKEND"] = "jax"
# Import Keras 3 and KerasHub modules
import keras
import keras_hub
通过 KerasHub 使用
计算机视觉和自然语言模型
现在,您可以通过 KerasHub 访问和使用 Keras 3 生态系统中的模型。以下是一些示例:
Gemma
Gemma 是由 Google 开发的一系列前沿且易于使用的开放模型。依托于与 Gemini 模型相同的研究和技术,Gemma 的基础模型在各种文本生成任务中表现出色,包括回答问题、总结信息以及进行逻辑推理。此外,您还可以针对特定需求自定义模型。
🔗 Gemma
https://ai.google.dev/gemma/docs/base
在此示例中,您可以使用 Keras 和 KerasHub 加载并开始使用 Gemma 2 2B 参数生成内容。有关 Gemma 变体的更多详细信息,请查看 Kaggle 上的 Gemma 模型卡。
# Load Gemma 2 2B preset from Kaggle models
gemma_lm = keras_hub.models.GemmaCausalLM.from_preset("gemma_2b_en")
# Start generating contents with Gemma 2 2B
gemma_lm.generate("Keras is a", max_length=32)
🔗 Gemma 模型卡
https://www.kaggle.com/models/google/gemma/
PaliGemma
PaliGemma 是一款紧凑型的开放模型,可以理解图像和文本。PaliGemma 从 PaLI-3 中汲取灵感,以 SigLIP 视觉模型和 Gemma 语言模型等开源组件为基础,可以针对有关图像的问题提供详细且富有洞察力的答案。因此,该模型可以更深入地了解视觉内容,从而实现诸多功能,例如为图像和短视频生成描述、识别对象甚至理解图像中的文本。
import os
# Define the Keras 3 backend you want to use - "jax", "tensorflow" or "torch"
os.environ["KERAS_BACKEND"] = "jax"
# Import Keras 3 and KerasHub modules
import keras
import keras_hub
from keras.utils import get_file, load_img, img_to_array
# Import PaliGemma 3B fine tuned with 224x224 images
pali_gemma_lm = keras_hub.models.PaliGemmaCausalLM.from_preset(
"pali_gemma_3b_mix_224"
)
# Download a test image and prepare it for usage with KerasHub
url = 'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'
img_path = get_file(origin=url)
img = img_to_array(load_img(image_path))
# Create the prompt with the question about the image
prompt = 'answer where is the cow standing?'
# Generate the contents with PaliGemma
output = pali_gemma_lm.generate(
inputs={
"images": img,
"prompts": prompt,
}
)
🔗 PaliGemma
https://ai.google.dev/gemma/docs/paligemma
🔗 PaLI-3
https://arxiv.org/abs/2310.09199
🔗 SigLIP 视觉模型
https://arxiv.org/abs/2303.15343
🔗 Gemma 语言模型
https://arxiv.org/abs/2403.08295
有关 Keras 3 上可用的预训练模型的更多详细信息,请在 Kaggle 上查看 Keras 中的模型列表。
🔗 Kaggle 上查看 Keras 中的模型列表
https://www.kaggle.com/organizations/keras/models
Stability.ai Stable Diffusion 3
您也可以使用计算机视觉模型。例如,您可以通过 KerasHub 使用 stability.ai Stable Diffusion 3:
from PIL import Image
from keras.utils import array_to_img
from keras_hub.models import StableDiffusion3TextToImage
text_to_image = StableDiffusion3TextToImage.from_preset(
"stable_diffusion_3_medium",
height=1024,
width=1024,
dtype="float16",
)
# Generate images with SD3
image = text_to_image.generate(
"photograph of an astronaut riding a horse, detailed, 8k",
)
# Display the generated image
img = array_to_img(image)
img
🔗 Stable Diffusion 3
https://stability.ai/news/stable-diffusion-3
有关 Keras 3 上可用的预训练计算机视觉模型的更多详细信息,请查看 Keras 中的模型列表。
🔗 Keras 中的模型列表
https://keras.io/api/keras_hub/models/
对于 KerasNLP 开发者而言,
有哪些变化?
从 KerasNLP 到 KerasHub 的过渡是一个简单的过程。只需要将 import 语句从 keras_nlp 更新为 keras_hub。
示例:以前,您可能需要导入 keras_nlp 才能使用 BERT 模型,如下所示
import keras_nlp
# Load a BERT model
classifier = keras_nlp.models.BertClassifier.from_preset(
"bert_base_en_uncased",
num_classes=2,
)
现在,您只需调整 import,即可使用 KerasHub:
import keras_hub
# Load a BERT model
classifier = keras_hub.models.BertClassifier.from_preset(
"bert_base_en_uncased",
num_classes=2,
)
对于 KerasCV 开发者而言,
有哪些变化?
如果您当前是 KerasCV 用户,更新到 KerasHub 能够为您带来以下好处:
简化模型加载:KerasHub 为加载模型提供了统一的 API,如果您同时使用 KerasCV 和 KerasNLP,这可以简化您的代码。
框架灵活性:如果您有兴趣探索 JAX 或 PyTorch 等不同框架,KerasHub 可以让您更轻松地将这些框架与 KerasCV 和 KerasNLP 模型结合起来使用。
集中式存储库:借助 KerasHub 的统一模型存储库,您可以更轻松地查找和访问模型,未来还可以在其中添加新架构。
如何使我的代码适配 KerasHub?
模型
目前,我们正在将 KerasCV 模型迁移到 KerasHub。虽然大多数模型已经可用,但有些仍在迁移中。请注意,Centerpillar 模型不会被迁移。您应该能够在 KerasHub 使用任何视觉模型,方法如下:
import keras_hub
# Load a model using preset
Model = keras_hub.models.<model_name>.from_preset('preset_name`)
# or load a custom model by specifying the backbone and preprocessor
Model = keras_hub.models.<model_name>(backbone=backbone, preprocessor=preprocessor)
🔗 Centerpillar
https://www.kaggle.com/models/keras/centerpillar
KerasHub 为 KerasCV 开发者带来了激动人心的新功能,提供了更高的灵活性和扩展能力。其中包括:
内置预处理
每个模型都配备了一个定制的预处理器,用于处理包括调整大小、重新缩放等常规任务,从而简化您的工作流程。
在此之前,预处理输入是在向模型提供输入之前手动执行的。
# Preprocess inputs for example
def preprocess_inputs(image, label):
# Resize rescale or do more preprocessing on inputs
return preprocessed_inputs
backbone = keras_cv.models.ResNet50V2Backbone.from_preset(
"resnet50_v2_imagenet",
)
model = keras_cv.models.ImageClassifier(
backbone=backbone,
num_classes=4,
)
output = model(preprocessed_input)
目前,任务模型的预处理已集成到现成的预设中。预处理器会对输入进行预处理,对样本图像进行大小调整和重新缩放。预处理器是任务模型的内在组件。尽管如此,开发者还是可以选择使用个性化的预处理器。
classifier = keras_hub.models.ImageClassifier.from_preset('resnet_18_imagenet')
classifier.predict(inputs)
损失函数
与增强层类似,以前 KerasCV 中的损失函数现在可在 Keras 中通过 keras.losses.<loss_function> 使用。例如,如果您当前正在使用 FocalLoss 函数:
import keras
import keras_cv
keras_cv.losses.FocalLoss(
alpha=0.25, gamma=2, from_logits=False, label_smoothing=0, **kwargs
)
🔗 FocalLoss 函数
https://keras.io/api/keras_cv/losses/focal_loss/
您只需调整损失函数定义代码,使用 keras.losses 而不是 keras_cv.losses:
import keras
keras.losses.FocalLoss(
alpha=0.25, gamma=2, from_logits=False, label_smoothing=0, **kwargs
)
开始使用 KerasHub
立即体验 KerasHub 的世界:
🔗 查看官方文档,开始使用 KerasHub
https://keras.io/keras_hub/
🔗 查看 KerasHub 入门指南
https://keras.io/guides/keras_hub/
🔗 试用预训练模型
https://keras.io/api/keras_hub/models/
🔗 探索源代码,期待看到您做出的贡献
https://github.com/keras-team/keras-hub/
🔗 在 Kaggle 上深入了解 Keras
https://www.kaggle.com/organizations/keras
加入 Keras 社区,释放统一、便利且高效的深度学习模型的力量。AI 的未来发展方向是多模态 AI,而 KerasHub 便是您打开这一未来的钥匙!欢迎您持续关注 "Android 开发者" 微信公众号,及时了解更多开发技术和产品更新等资讯动态!
推荐阅读
如页面未加载,请刷新重试
点击屏末 | 阅读原文 | 即刻查看 KerasHub 入门指南