AI Edge Torch Generative API for Custom LLMs on Device

We’re excited to enable developers to seamlessly bring new on-device generative AI models to edge devices. To meet that need, we’re announcing the AI Edge Torch Generative API, which allows developers to author high-performance LLMs in PyTorch for deployment using the TensorFlow Lite (TFLite) runtime. This is the second in a series of blog posts covering Google AI Edge developer releases. The first post in the series introduced Google AI Edge Torch, which enables high-performance inference of PyTorch models on mobile devices using the TFLite runtime.

The AI Edge Torch Generative API enables developers to bring powerful new capabilities on-device, such as summarization, content generation, and more. We already enable developers to bring some of the most popular LLMs to devices using the MediaPipe LLM Inference API. We are now excited to enable developers to bring any supported model on device with great performance. The initial version of the AI Edge Torch Generative API offers the following:

  • Easy-to-use authoring API for custom transformer support
  • Great performance on CPU, with GPU and NPU support coming soon
  • Fully compatible with existing TFLite deployment flows, including quantization and runtime
  • Works with models such as TinyLlama, Phi-2, and Gemma 2B
  • Compatible with both the TFLite runtime and MediaPipe LLM runtime interfaces, with Android, iOS, and Web support

In this blog post we’ll deep dive into performance, portability, the authoring developer experience, the end-to-end inference pipeline, and the debug toolchain. Further documentation and examples are available here.


Performance

As part of our work to get some of the most popular LLMs running seamlessly through the MediaPipe LLM Inference API, our team authored several fully hand-written transformers with state-of-the-art on-device performance (MediaPipe LLM Inference API blog). A few themes emerged from this work: how to represent attention effectively, the use of quantization, and the importance of a good KV cache representation. The Generative API makes each of these easy to express (as we’ll see in the next section), while still achieving performance that is >90% of our hand-written versions with far greater developer velocity.

The following table shows key benchmarks across 3 model examples:

These are benchmarked on big cores, with 4 CPU threads, and are the fastest CPU implementations of these models we’re currently aware of on the devices listed.


Authoring Experience

The core authoring library provides basic building blocks for common transformer models (encoder-only, decoder-only, or encoder-decoder style, etc.). It allows you to either author a model from scratch, or re-author an existing model for improved performance. We recommend that most users re-author, since it requires no training or fine-tuning steps. The key benefits of Generative API authoring include:

  • A set of core transformer building blocks optimized for convertibility, performance, and platform portability, which are easy to mix and match with regular PyTorch ops.
  • An easy weight re-mapping mechanism.
  • Intuitive quantization APIs.
  • Multi-signature export with prefill, decode, or customized signatures, which works seamlessly with pre-canned MediaPipe tasks/LLM Inference APIs.

As an example, here we showcase how to re-author TinyLlama (1.1B)’s core functionality with around 50 lines of Python using the new Generative API.

Step 1: Define the model structure

import torch
import torch.nn as nn

from ai_edge_torch.generative.layers.attention import TransformerBlock
import ai_edge_torch.generative.layers.attention_utils as attn_utils
import ai_edge_torch.generative.layers.builder as builder
import ai_edge_torch.generative.layers.model_config as cfg


class TinyLLamma(nn.Module):

  def __init__(self, config: cfg.ModelConfig):
    super().__init__()

    self.config = config
    # Construct model layers.
    self.lm_head = nn.Linear(
        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
    )
    self.tok_embedding = nn.Embedding(
        config.vocab_size, config.embedding_dim, padding_idx=0
    )
    self.transformer_blocks = nn.ModuleList(
        TransformerBlock(config) for _ in range(config.num_layers)
    )
    self.final_norm = builder.build_norm(
        config.embedding_dim,
        config.final_norm_config,
    )
    self.rope_cache = attn_utils.build_rope_cache(
        size=config.kv_cache_max,
        dim=int(config.attn_config.rotary_percentage * config.head_dim),
        base=10_000,
        condense_ratio=1,
        dtype=torch.float32,
        device=torch.device("cpu"),
    )
    self.mask_cache = attn_utils.build_causal_mask_cache(
        size=config.kv_cache_max, dtype=torch.float32, device=torch.device("cpu")
    )
    self.config = config

Step 2: Define the model’s forward function

  @torch.inference_mode
  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
    B, T = idx.size()
    cos, sin = self.rope_cache
    cos = cos.index_select(0, input_pos)
    sin = sin.index_select(0, input_pos)
    mask = self.mask_cache.index_select(2, input_pos)
    mask = mask[:, :, :, : self.config.kv_cache_max]

    # forward the model itself
    x = self.tok_embedding(idx)  # token embeddings of shape (b, t, n_embd)

    for i, block in enumerate(self.transformer_blocks):
      x = block(x, (cos, sin), mask, input_pos)

    x = self.final_norm(x)
    res = self.lm_head(x)  # (b, t, vocab_size)
    return res

Step 3: Map the old model weights

The library allows you to map weights easily with the ModelLoader APIs, for example:

import ai_edge_torch.generative.utilities.loader as loading_utils


# This map will associate old tensor names with the new model.
TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
    ff_up_proj="model.layers.{}.mlp.up_proj",
    ff_down_proj="model.layers.{}.mlp.down_proj",
    ff_gate_proj="model.layers.{}.mlp.gate_proj",
    attn_query_proj="model.layers.{}.self_attn.q_proj",
    attn_key_proj="model.layers.{}.self_attn.k_proj",
    attn_value_proj="model.layers.{}.self_attn.v_proj",
    attn_output_proj="model.layers.{}.self_attn.o_proj",
    pre_attn_norm="model.layers.{}.input_layernorm",
    pre_ff_norm="model.layers.{}.post_attention_layernorm",
    embedding="model.embed_tokens",
    final_norm="model.norm",
    lm_head="lm_head",
)
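
With the name map defined, loading a checkpoint onto the re-authored model takes only a couple of lines. The sketch below is illustrative: the checkpoint path is a placeholder, and the construction of the cfg.ModelConfig from Step 1 is not shown.

# Illustrative usage of ModelLoader with the name map above.
model = TinyLLamma(config)  # config: a cfg.ModelConfig built for TinyLlama (not shown)
loader = loading_utils.ModelLoader("path/to/tiny_llama_checkpoint", TENSOR_NAMES)
loader.load(model)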

After these steps are done, you can run a few sample inputs to verify numerical correctness (see link) of the re-authored model. If the numerical check passes, you can proceed to the convert & quantize step.
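
As a minimal sketch of such a check, you can compare logits from the re-authored model against the original checkpoint on the same token IDs; the model handles, token values, and tolerance below are all illustrative rather than part of the API.

import torch

# Hypothetical handles: `reauthored_model` is the TinyLLamma built above,
# `original_model` is the corresponding upstream causal LM checkpoint.
tokens = torch.zeros((1, 10), dtype=torch.long)
tokens[0, :4] = torch.tensor([1, 2, 3, 4], dtype=torch.long)  # arbitrary sample IDs
input_pos = torch.arange(0, 10)

reauthored_logits = reauthored_model(tokens, input_pos)
original_logits = original_model(tokens).logits

# Compare only the non-padded positions; the tolerance is illustrative.
assert torch.allclose(
    reauthored_logits[0, :4, :], original_logits[0, :4, :], atol=1e-02
)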


Conversion & Quantization

With the conversion APIs provided by ai_edge_torch, you can use the same API to convert (re-authored) transformer models into a highly optimized TensorFlow Lite model. The conversion process involves the following key steps:

1) Export to StableHLO. The PyTorch model is traced and compiled into an FX graph with ATen ops by the torch dynamo compiler, then lowered to a StableHLO graph by ai_edge_torch.

2) ai_edge_torch runs further compiler passes on StableHLO, including op fusion/folding, etc., and generates a highly performant TFLite flatbuffer (with fused ops for SDPA and KVCache).
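
At its simplest, converting the re-authored model looks like converting any other PyTorch model with ai_edge_torch; the sketch below uses illustrative sample inputs, and the full prefill/decode export flow is shown in the multi-signature section below.

import ai_edge_torch
import torch

# `model` is the re-authored TinyLLamma instance; sample inputs trace the graph.
sample_tokens = torch.zeros((1, 512), dtype=torch.long)
sample_input_pos = torch.arange(0, 512)

edge_model = ai_edge_torch.convert(model.eval(), (sample_tokens, sample_input_pos))
edge_model.export("/tmp/tiny_llama_float.tflite")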

Quantization

The core Generative API library also provides a set of quantization APIs covering common LLM quantization recipes. The recipe is passed as an additional parameter to the ai_edge_torch converter API, which then handles quantization automatically. In future releases, we expect to expand the set of quantization modes available.
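
Continuing the conversion sketch above, a pre-built recipe is passed through the converter’s quant_config argument; the recipe shown here is the same one used in the export example below.

from ai_edge_torch.generative.quantize import quant_recipes

# Dynamic int8 quantization of linear layers, applied during conversion.
quant_config = quant_recipes.full_linear_int8_dynamic_recipe()
edge_model = ai_edge_torch.convert(
    model.eval(), (sample_tokens, sample_input_pos), quant_config=quant_config
)
edge_model.export("/tmp/tiny_llama_int8.tflite")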

Multi-signature export

We identified that in real inference scenarios, LLM models need clearly separated (disaggregated) inference functions (prefill, decode) to achieve the best serving performance. This is partly based on the observation that prefill and decode may take different tensor shapes: prefill is compute-bound while decode is memory-bound. For large LLMs, it’s critical to avoid duplicating model weights between prefill and decode. We achieve this using the existing multi-signature feature in TFLite and ai_edge_torch, which lets you easily define multiple entry points for the model, as shown below.

import ai_edge_torch
import torch

from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
from ai_edge_torch.generative.quantize import quant_recipes


def convert_tiny_llama_to_tflite(
    prefill_seq_len: int = 512,
    kv_cache_max_len: int = 1024,
    quantize: bool = True,
):
  pytorch_model = tiny_llama.build_model(kv_cache_max_len=kv_cache_max_len)
  
  # Tensors used to trace the model graph during conversion.
  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
  prefill_input_pos = torch.arange(0, prefill_seq_len)
  decode_token = torch.tensor([[0]], dtype=torch.long)
  decode_input_pos = torch.tensor([0], dtype=torch.int64)

  # Set up quantization for the model.
  quant_config = quant_recipes.full_linear_int8_dynamic_recipe() if quantize else None
  
  edge_model = (
      ai_edge_torch.signature(
          'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
      )
      .signature('decode', pytorch_model, (decode_token, decode_input_pos))
      .convert(quant_config=quant_config)
  )
  edge_model.export(f'/tmp/tiny_llama_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite')
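
Once exported, both signatures can be inspected and exercised from Python with the TFLite interpreter. The sketch below assumes the file path produced by the export above; check the signature input/output names with get_signature_list() before invoking, since they depend on the export.

import tensorflow as tf

# Load the multi-signature model exported above.
interpreter = tf.lite.Interpreter(
    model_path="/tmp/tiny_llama_seq512_kv1024.tflite"
)

# List the exported signatures and their input/output tensor names.
print(interpreter.get_signature_list())

# Each signature gets its own runner; inputs are passed as keyword
# arguments matching the names printed above.
prefill_runner = interpreter.get_signature_runner("prefill")
decode_runner = interpreter.get_signature_runner("decode")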

LLM-specific performance optimizations

During our performance investigation phase, we found several aspects that are critical for improving LLM performance:

1) High-performance SDPA and KVCache: we found that without sufficient compiler optimizations/fusions, the converted TFLite model will not have great performance, given the granular ops in these functions. To address this, we introduced high-level function boundaries and StableHLO composite ops.

2) Leveraging TFLite’s XNNPack delegate to further accelerate SDPA: it’s critical to ensure heavy MatMul/matrix-vector computations are well optimized. The XNNPack library has excellent performance for these primitives across a broad range of mobile CPUs.

3) Avoiding wasteful computations: static-shape models can induce more compute than is minimally required if models have a long fixed input message size in the prefill stage or a large fixed sequence length in the decode stage.

4) Runtime memory consumption: we introduced a weight caching/pre-packing mechanism in TFLite’s XNNPack delegate to significantly lower peak memory usage.


Deployment

LLM inference typically involves many pre/post-processing steps and sophisticated orchestration, e.g. tokenization, sampling, and autoregressive decoding logic. To this end, we provide both MediaPipe-based solutions and a pure C++ inference example.


Use MediaPipe LLM Inference API

The MediaPipe LLM Inference API is a high-level API that supports LLM inference using a prompt-in/prompt-out interface. It takes care of all the complexity of implementing the LLM pipeline under the hood, and makes deployment much easier and smoother. To deploy using the MP LLM Inference API, you need to ensure you convert models using the expected prefill and decode signatures, and create a bundle as shown in the code below:

def bundle_tinyllama_q8():
  output_file = "PATH/tinyllama_q8_seq1024_kv1280.task"
  tflite_model = "PATH/tinyllama_prefill_decode_hlfb_quant.tflite"
  tokenizer_model = "PATH/tokenizer.model"
  config = llm_bundler.BundleConfig(
      tflite_model=tflite_model,
      tokenizer_model=tokenizer_model,
      start_token="<s>",
      stop_tokens=["</s>"],
      output_filename=output_file,
      enable_bytes_to_unicode_mapping=False,
  )
  llm_bundler.create_bundle(config)

Pure C++ inference through the TFLite runtime

We also provide an easy-to-use C++ example (without the MediaPipe dependency) to showcase how to run an end-to-end text generation example. Developers can use this example as a starting point for integrating the exported models with their unique production pipelines and requirements, which enables better customization and flexibility.
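
To make the orchestration concrete, here is a minimal greedy-decoding sketch at the PyTorch level that mirrors what the C++ example does with the converted model; it assumes the re-authored TinyLLamma from above with its KV cache enabled, and `prompt_ids` already produced by a tokenizer.

import torch

def greedy_generate(model, prompt_ids, max_new_tokens=16):
  # Prefill: run the whole prompt once to populate the KV cache.
  tokens = torch.tensor([prompt_ids], dtype=torch.long)
  input_pos = torch.arange(0, tokens.shape[1])
  logits = model(tokens, input_pos)
  next_token = int(torch.argmax(logits[0, -1]))
  output = list(prompt_ids) + [next_token]

  # Decode: feed one token at a time at its absolute position.
  for _ in range(max_new_tokens - 1):
    pos = torch.tensor([len(output) - 1], dtype=torch.long)
    logits = model(torch.tensor([[next_token]], dtype=torch.long), pos)
    next_token = int(torch.argmax(logits[0, -1]))
    output.append(next_token)
  return output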


Cross-platform assist

Since the core inference runtime is TFLite, the whole pipeline can be easily integrated into your Android (included in Google Play) or iOS apps without any modifications. This ensures that models converted with the new Generative API are immediately deployable by simply adding a few custom op dependencies. In future releases, we’ll bring GPU support to both Android & iOS, and target ML accelerators (TPU, NPU) as well.


Tooling

The recently announced Model Explorer is a useful tool for visualizing large models such as Gemma 2B. Hierarchical viewing and side-by-side comparison make it easy to visualize the original, re-authored, and converted model versions. For more details on this and how you can visualize benchmark info for performance tuning, check out this blog post.

Below is an example of how we used this when authoring the PyTorch TinyLlama model, showing the PyTorch export() model alongside the TFLite model. Using Model Explorer, we can easily compare how each layer (e.g. RMSNorms, SelfAttention) is expressed.


A side-by-side comparison between TinyLlama PyTorch and the converted TFLite model
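
If you want to inspect the converted model locally, the Model Explorer pip package (ai-edge-model-explorer) can open the TFLite graph in a browser; this is a minimal sketch, and the file path is illustrative.

import model_explorer

# Launch the Model Explorer UI with the converted TFLite model loaded.
model_explorer.visualize("/tmp/tiny_llama_seq512_kv1024.tflite")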

Summary & what’s next

The AI Edge Torch Generative API is a powerful companion to the prebuilt optimized models available in the MediaPipe LLM Inference API for developers who want to enable their own generative AI models on device. In the coming months, expect new updates including web support, improved quantization, and wider compute support beyond CPU. We’re also interested in exploring even better framework integration.

This is an early preview of the library, which is in an experimental stage with the goal of engaging with the developer community. Please expect APIs to change, rough edges, and limited support for quantization and models. But there’s plenty to get started with already in our GitHub repo: jump in and feel free to share PRs, issues, and feature requests.

In part 3 of this series, we’ll take a deeper look at the Model Explorer visualization tool, which allows developers to visualize, debug, and explore models.



Acknowledgements

This work is a collaboration across multiple functional teams at Google. We’d like to thank all team members who contributed to this work: Aaron Karp, Advait Jain, Akshat Sharma, Alan Kelly, Andrei Kulik, Arian Afaian, Chun-nien Chan, Chuo-Ling Chang, Cormac Brick, Eric Yang, Frank Barchard, Gunhyun Park, Han Qi, Haoliang Zhang, Ho Ko, Jing Jin, Joe Zoe, Juhyun Lee, Kevin Gleason, Khanh LeViet, Kris Tonthat, Kristen Wright, Lin Chen, Linkun Chen, Lu Wang, Majid Dadashi, Manfei Bai, Mark Sherwood, Matthew Soulanille, Matthias Grundmann, Maxime Brénon, Michael Levesque-Dion, Mig Gerard, Milen Ferev, Mohammadreza Heydary, Na Li, Paul Ruiz, Pauline Sho, Pei Zhang, Ping Yu, Pulkit Bhuwalka, Quentin Khan, Ram Iyengar, Renjie Wu, Rocky Rhodes, Sachin Kotwani, Sandeep Dasgupta, Sebastian Schmidt, Siyuan Liu, Steven Toribio, Suleman Shahid, Tenghui Zhu, T.J. Alumbaugh, Tyler Mullen, Weiyi Wang, Wonjoo Lee, Yi-Chun Kuo, Yishuang Pang, Yu-hui Chen, Zoe Wang, Zichuan Wei.
