# Ryan's Trick for Distributed Groth16

Please read [Groth16](https://fractalyze.gitbook.io/intro/~/revisions/0AAov1j5GF4J6Ca62R1w/zk/snark/groth16) beforehand!

## Bottlenecks

These are the three most computation-intensive parts that must be distributed:

#### 1. **Witness Reduction**

This is required to compute $$a(X), b(X), c(X)$$:

$$
\sum\_{i=0}^m (A\_{j, i} \cdot z\_i), \quad \sum\_{i=0}^m (B\_{j, i} \cdot z\_i), \quad \sum\_{i=0}^m (C\_{j, i} \cdot z\_i)
$$

#### 2. **FFT / Inverse FFT**

This is needed to compute $$h\_i$$:

$$
\begin{aligned}
a(X) &= \sum\_{j = 0}^{n-1} L\_j(X) \left(\sum\_{i = 0}^{m} A\_{j, i} \cdot z\_i\right),\quad  \bm{a} = \left(a'(\omega^0), \dots, a'(\omega^{n-1})\right) \\
b(X) &= \sum\_{j = 0}^{n-1} L\_j(X) \left(\sum\_{i = 0}^{m} B\_{j, i} \cdot z\_i\right),\quad  \bm{b} = \left(b'(\omega^0), \dots, b'(\omega^{n-1})\right) \\
c(X) &= \sum\_{j = 0}^{n-1} L\_j(X) \left(\sum\_{i = 0}^{m} C\_{j, i} \cdot z\_i\right),\quad  \bm{c} = \left(c'(\omega^0), \dots, c'(\omega^{n-1})\right)
\end{aligned}
$$

#### 3. **MSM**

This is used to compute $$(\[A]\_1, \[B]\_2, \[C]\_1)$$, involving five MSMs, one of which depends on $$h\_i$$:

$$
\begin{aligned}
&\sum\_{i=0}^m z\_i \[a\_i(x)]*1, \quad \sum*{i=0}^m z\_i \[b\_i(x)]*1, \quad \sum*{i=0}^m z\_i \[b\_i(x)]*2, \\
&\sum*{i = \ell + 1}^{m} z\_i \left\[ \frac{\beta  a\_i(x) + \alpha  b\_i(x) + c\_i(x)}{\delta} \right]*1, \\
&\sum*{i = 0}^{n - 2} h\_i \left\[ \frac{{L'}\_{2i + 1}(x)}{\delta} \right]\_1
\end{aligned}
$$

## Partial Witness Reduction + Partial Inverse FFT

### Motivation

The primary motivation for performing **Partial Witness Reduction** and **Partial FFT** is to **partition the data structures** $$A, B, C, \bm{z}$$ across ddd devices. By doing so, each device only stores a fraction of the full data, significantly **reducing memory overhead** and enabling large-scale proving even when the entire witness cannot fit into a single GPU’s memory.

In **Partial Witness Reduction** and **Partial Inverse FFT**, the matrices $$A, B, C$$ are typically sparse, making it difficult to precisely estimate the memory savings from their partitioning. However, the vector $$\bm{z}$$ is dense and evenly split across ddd devices, so its memory overhead is **reduced by a factor of** $$d$$.

### Protocol

Assuming $$m + 1$$ is divisible by $$d$$ (the number of devices), the polynomials can be decomposed as follows:

$$
\begin{aligned}
a(X) &= \sum\_{k = 0}^{d-1} \underbrace{\left(\sum\_{j = 0}^{n-1} L\_j(X) \left(\sum\_{i = k(m + 1) / d}^{(k + 1)(m + 1) / d} A\_{j, i} \cdot z\_i\right)\right)}*{a^{(k)}(X)} \\
b(X) &= \sum*{k = 0}^{d-1} \underbrace{\left(\sum\_{j = 0}^{n-1} L\_j(X) \left(\sum\_{i = k(m + 1) / d}^{(k + 1)(m + 1) / d} B\_{j, i} \cdot z\_i\right)\right)}*{b^{(k)}(X)} \\
c(X) &= \sum*{k = 0}^{d-1} \underbrace{\left(\sum\_{j = 0}^{n-1} L\_j(X) \left(\sum\_{i = k(m + 1) / d}^{(k + 1)(m + 1) / d} C\_{j, i} \cdot z\_i\right)\right)}\_{c^{(k)}(X)}
\end{aligned}
$$

Thus, matrices and vectors can be partitioned into $$d$$ parts:

* $$A \to (A\_k)\_{k=0}^{d-1} \in \left(\mathbb{F}^{n \times \frac{m+1}{d}}\right)^d$$
* $$B \to (B\_k)\_{k=0}^{d-1} \in \left(\mathbb{F}^{n \times \frac{m+1}{d}}\right)^d$$
* $$C \to (C\_k)\_{k=0}^{d-1} \in \left(\mathbb{F}^{n \times \frac{m+1}{d}}\right)^d$$
* $$\bm{z} \to (z\_k)\_{k=0}^{d-1} \in \left(\mathbb{F}^{\frac{m + 1}{d}}\right)^d$$

### Examples

Suppose we are given the following matrices $$A, B, C$$, and vector $$\bm{z}$$:

$$
A = \begin{bmatrix}
A\_{0, 0} & A\_{0, 1} \\
A\_{1, 0} & A\_{1, 1} \\
A\_{2, 0} & A\_{2, 1} \\
A\_{3, 0} & A\_{3, 1} \\
\end{bmatrix}, \quad
B = \begin{bmatrix}
B\_{0, 0} & B\_{0, 1} \\
B\_{1, 0} & B\_{1, 1} \\
B\_{2, 0} & B\_{2, 1} \\
B\_{3, 0} & B\_{3, 1} \\
\end{bmatrix}, \quad
C = \begin{bmatrix}
C\_{0, 0} & C\_{0, 1} \\
C\_{1, 0} & C\_{1, 1} \\
C\_{2, 0} & C\_{2, 1} \\
C\_{3, 0} & C\_{3, 1} \\
\end{bmatrix}, \quad \bm{z} = (z\_0, z\_1)
$$

Then the polynomials $$a(X), b(X), c(X)$$ are constructed as follows:

$$
\begin{align\*}
a(X) &= L\_0(X)(A\_{0, 0}z\_0 + A\_{0, 1}z\_1) + L\_1(X)(A\_{1, 0}z\_0 + A\_{1, 1}z\_1) + L\_2(X)(A\_{2, 0}z\_0 + A\_{2, 1}z\_1) + L\_3(X)(A\_{3, 0}z\_0 + A\_{3, 1}z\_1) \\
&= \underbrace{L\_0(X)(A\_{0, 0}z\_0) + L\_1(X)(A\_{1, 0}z\_0) + L\_2(X)(A\_{2, 0}z\_0) + L\_3(X)(A\_{3, 0}z\_0)}*{a^{(0)(X)}} \\
&\quad + \underbrace{L\_0(X)(A*{0, 1}z\_1) + L\_1(X)(A\_{1, 1}z\_1) + L\_2(X)(A\_{2, 1}z\_1) + L\_3(X)(A\_{3, 1}z\_1)}*{a^{{(1)(X)}}} \\
b(X) &= L\_0(X)(B*{0, 0}z\_0 + B\_{0, 1}z\_1) + L\_1(X)(B\_{1, 0}z\_0 + B\_{1, 1}z\_1) + L\_2(X)(B\_{2, 0}z\_0 + B\_{2, 1}z\_1) + L\_3(X)(B\_{3, 0}z\_0 + B\_{3, 1}z\_1) \\
&= \underbrace{L\_0(X)(B\_{0, 0}z\_0) + L\_1(X)(B\_{1, 0}z\_0) + L\_2(X)(B\_{2, 0}z\_0) + L\_3(X)(B\_{3, 0}z\_0)}*{b^{(0)(X)}} \\
&\quad + \underbrace{L\_0(X)(B*{0, 1}z\_1) + L\_1(X)(B\_{1, 1}z\_1) + L\_2(X)(B\_{2, 1}z\_1) + L\_3(X)(B\_{3, 1}z\_1)}*{b^{{(1)(X)}}} \\
c(X) &= L\_0(X)(C*{0, 0}z\_0 + C\_{0, 1}z\_1) + L\_1(X)(C\_{1, 0}z\_0 + C\_{1, 1}z\_1) + L\_2(X)(C\_{2, 0}z\_0 + C\_{2, 1}z\_1) + L\_3(X)(C\_{3, 0}z\_0 + C\_{3, 1}z\_1) \\
&= \underbrace{L\_0(X)(C\_{0, 0}z\_0) + L\_1(X)(C\_{1, 0}z\_0) + L\_2(X)(C\_{2, 0}z\_0) + L\_3(X)(C\_{3, 0}z\_0)}*{c^{(0)(X)}} \\
&\quad + \underbrace{L\_0(X)(C*{0, 1}z\_1) + L\_1(X)(C\_{1, 1}z\_1) + L\_2(X)(C\_{2, 1}z\_1) + L\_3(X)(C\_{3, 1}z\_1)}\_{c^{{(1)(X)}}} \\
\end{align\*}
$$

## Full FFT

### Motivation

<figure><img src="https://755218234-files.gitbook.io/~/files/v0/b/gitbook-x-prod.appspot.com/o/spaces%2Frwz1ZAZJtK5FHz4Y1esA%2Fuploads%2FP96edvbyR51FEWGyRgpD%2Fimage.png?alt=media&#x26;token=a4a79aee-1006-4983-876a-bbd2265f3e28" alt=""><figcaption></figcaption></figure>

In protocols like DIZK, distributing FFT requires 3 communications per FFT, which leads to substantial overhead. However, according to data from [ICICLE-Snark](https://medium.com/@ingonyama/icicle-snark-the-fastest-groth16-implementation-in-the-world-00901b39a21f), the dominant cost in Groth16 proving is **MSM**, not FFT.\
Therefore, we choose **not to split the FFT**, and instead perform a **Full FFT after reconstructing**. This design reduces communication complexity while focusing optimization efforts on the actual bottleneck.

### Protocol

After the partial inverse FFT, each device performs a [`Reduce`](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#reduce) operation to reconstruct the full polynomials $$a(X), b(X), c(X)$$.&#x20;

We use `Reduce` instead of `AllReduce` for the Full FFT step in order to minimize data communication cost. As a result, a single device is responsible for performing the element-wise multiplication and the forward FFT on the fully reduced polynomials.

### Example

Suppose we have 4 devices, and each device holds polynomials $$a^{(k)}(X), b^{(k)}(X), c^{(k)}(X)$$ for $$k = 0, 1, 2, 3$$. To proceed with proof generation, every device must eventually obtain the full sums:

$$
a(X) = a^{(0)}(X) + a^{(1)}(X) + a^{(2)}(X) + a^{(3)}(X) \\
b(X) = b^{(0)}(X) + b^{(1)}(X) + b^{(2)}(X) + b^{(3)}(X) \\
c(X) = c^{(0)}(X) + c^{(1)}(X) + c^{(2)}(X) + c^{(3)}(X) \\
$$

## Partial MSM

### Motivation

Each MSM can be intuitively split across devices. When performing the 5 MSMs involved in Groth16 proving, both the scalar inputs $$z\_i$$​ and $$h\_i$$ are partitioned across ddd devices, reducing their memory footprint by a factor of $$d$$. As a result, the corresponding **base points** used in each MSM are also reduced proportionally, leading to a significant decrease in **per-device memory usage** and enabling more efficient multi-GPU computation.

### Protocol

Assuming $$n - 1$$ is divisible by $$d$$:

$$
\begin{aligned}
&\sum\_{i=0}^m z\_i \[a\_i(x)]*1 = \sum*{k=0}^{d-1} \left( \sum\_{i=k(m+1)/d}^{(k+1)(m+1)/d} z\_i \[a\_i(x)]*1 \right) \\
&\sum*{i=0}^m z\_i \[b\_i(x)]*1 = \sum*{k=0}^{d-1} \left( \sum\_{i=k(m+1)/d}^{(k+1)(m+1)/d} z\_i \[b\_i(x)]*1 \right) \\
&\sum*{i=0}^m z\_i \[b\_i(x)]*2 = \sum*{k=0}^{d-1} \left( \sum\_{i=k(m+1)/d}^{(k+1)(m+1)/d} z\_i \[b\_i(x)]*2 \right) \\
&\sum*{i = \ell + 1}^{m} z\_i \left\[ \frac{\beta  a\_i(x) + \alpha  b\_i(x) + c\_i(x)}{\delta} \right]*1 = \sum*{k=0}^{d-1} \left( \sum\_{i= \max(k(m+1)/d, \ell + 1)}^{(k+1)(m+1)/d} \cdots \right) \\
&\sum\_{i = 0}^{n - 2} h\_i \left\[ \frac{{L'}*{2i + 1}(x)}{\delta} \right]*1 = \sum*{k=0}^{d-1} \left( \sum*{i=k(n-1)/d}^{(k+1)(n-1)/d} \cdots \right) \\
\end{aligned}
$$

The first four MSMs can be computed independently. The last one requires a precomputed $$h\_i$$, which each device can hold after the [`Send`](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/p2p.html#ncclsend) & [`Recv`](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/p2p.html#ncclsend) operations.

### Example

```cpp
const int field_element_size = 32; // 32 bytes per BN254 field element
const int num_coeffs_per_gpu = ...; // e.g., 2^23 / num_gpus
const size_t chunk_size = num_coeffs_per_gpu * field_element_size;

if (rank == 0) {
    // `full_poly_data` points to the full polynomial coefficients (uint8_t*)
    for (int i = 0; i < num_gpus; ++i) {
        if (i ! = rank) {
            uint8_t* chunk_ptr = full_poly_data + i * chunk_size;
            ncclSend(chunk_ptr, chunk_size, ncclUint8, i, comm, stream);
        }
    }
} else {
    // `my_poly_chunk` will hold this GPU's assigned coefficients
    ncclRecv(my_poly_chunk, chunk_size, ncclUint8, 0, comm, stream);
}
```

## Put Together

In the following expressions, red highlights indicate the parts that must be reduced across devices. The remaining parts can be computed on the host to finalize the Groth16 proof:

$$
\begin{aligned}
\[A]\_1 &= \[\alpha]*1 + \textcolor{red}{\underbrace{\sum*{i=0}^m z\_i \[a\_i(x)]*1}*{A}} + r \[\delta]\_1 \\
\[B]*2 &= \[\beta]*2 + \textcolor{red}{\underbrace{\sum*{i=0}^m z\_i \[b\_i(x)]*2}*{B}} + s \[\delta]*2 \\
\[C]*1 &= \textcolor{red}{\underbrace{\sum*{i = \ell + 1}^{m} z\_i \left\[ \frac{\beta  a\_i(x) + \alpha  b\_i(x) + c\_i(x)}{\delta} \right]*1}*{C\_1} + \underbrace{\sum*{i = 0}^{n - 2} h\_i \left\[ \frac{{L'}*{2i + 1}(x)}{\delta} \right]*1}*{C\_2}} + s\[A]\_1 + r\textcolor{red}{\underbrace{\[B]*1}*{C\_3}} - rs\[\delta]\_1
\end{aligned}
$$

<figure><img src="https://755218234-files.gitbook.io/~/files/v0/b/gitbook-x-prod.appspot.com/o/spaces%2Frwz1ZAZJtK5FHz4Y1esA%2Fuploads%2FJ6FWOOtRDk7FljOMOwVH%2FScreenshot%202025-04-22%20at%205.57.37%E2%80%AFPM.png?alt=media&#x26;token=3bd80134-b3d9-4985-8e96-ea49d2fbb002" alt=""><figcaption></figcaption></figure>

In protocols like [DIZK](https://fractalyze.gitbook.io/intro/~/revisions/0AAov1j5GF4J6Ca62R1w/zk/distributed-zk/dizk), distributing FFT requires many communications per FFT. In contrast, our protocol incurs less communications only during the `Reduce`, `Send` and `Recv` steps, while the subsequent `AllReduce` step involves only a constant number of group elements. As a result, the total communication cost is significantly lower than that of DIZK. Moreover, this distribution strategy can be efficiently implemented using [SPMD](https://arxiv.org/abs/2105.04663).

| Communication Step | Communication Cost                             |
| ------------------ | ---------------------------------------------- |
| `Reduce`           | $$3 \times (d - 1) \times n$$                  |
| `Send` & `Recv`    | $$(d - 1) \times \frac{n}{d}$$                 |
| `AllReduce`        | $$O(d)$$                                       |
| DIZK's `AllToAll`  | $$3 \times 3 \times 2\times (d - 1) \times n$$ |

## Analysis: RISC0 Groth16 proof

Currently, the [stark\_verify.circom](https://github.com/risc0/risc0/blob/4ac16240c8e76583ee6a68a3b71a21c82f37313c/groth16_proof/groth16/stark_verify.circom) used in RISC0 has the following characteristics:

```bash
> circom --r1cs stark_verify.circom 
template instances: 349
non-linear constraints: 5676573
linear constraints: 0
public inputs: 0
private inputs: 25749 (22333 belong to witness)
public outputs: 5
wires: 5635930
labels: 10298490
Written successfully: ./stark_verify.r1cs
Everything went okay
```

The number of rows in the $$A, B, C$$ matrices is **5,676,573**, which means an SRS of size $$2^{23}$$ is required. The number of columns is **5,635,930**.

## Reduce Analysis

We now analyze the cost of the first and most expensive communication step in our distributed system: the `Reduce` operation.

Assume 4 GPUs each compute and store partial results of the polynomials $$a^{(k)}(X), b^{(k)}(X), c^{(k)}(X)$$, and together, they must aggregate them into full polynomials of size $$2^{23}$$ over the BN254 field. The total volume of data involved is:

* **3 polynomials per GPU × 256 MB each to be sent to the leader GPU = 768 MB per GPU**
* **Total communication volume on receiver side of leader GPU = 3 GPUs × 768 MB = 2304 MB**

The goal of `Reduce` is to **sum these partial polynomials in the leader device**, so that 1 leader GPU retains the complete 768 MB result (i.e., full $$a(X), b(X), c(X)$$).

### Computation Cost

Field additions in BN254 are extremely lightweight on GPUs. Each operation for $$a\_i(X), b\_i(X), c\_i(X)$$ takes less than 1 ms and is negligible in the overall runtime.

### Communication Cost

#### **PCIe Transfer Speed**

| Item                           | Value                                       |
| ------------------------------ | ------------------------------------------- |
| Interface                      | PCIe Gen4 x16                               |
| Theoretical Bandwidth          | 32 GB/s                                     |
| Measured Bandwidth (via NCCL)  | 24 \~ 28 GB/s                               |
| Receive time by the leader GPU | 2304 MB / (24 \~ 28) GB/s ≈ **82 \~ 96 ms** |
| Estimated with NCCL overhead   | **92 \~ 106 ms**                            |

<figure><img src="https://755218234-files.gitbook.io/~/files/v0/b/gitbook-x-prod.appspot.com/o/spaces%2Frwz1ZAZJtK5FHz4Y1esA%2Fuploads%2Fuc99ePizDPNFmCDDS3Uu%2Fimage.png?alt=media&#x26;token=8a83e0b5-7b68-41cd-8807-3691293b3852" alt=""><figcaption><p>Source: <a href="https://www.marvell.com/content/dam/marvell/en/blogs/2024/01/PCIe-Gen6-IO-Bandwidth.png">https://www.marvell.com/content/dam/marvell/en/blogs/2024/01/PCIe-Gen6-IO-Bandwidth.png</a></p></figcaption></figure>

We chose **PCIe Gen4 x16** because it offers a well-balanced trade-off between **performance, cost, and ecosystem stability**. Gen4 provides up to **32 GB/s bidirectional bandwidth**, which is sufficient for most real-world proving workloads, especially when combined with smart overlapping strategies between computation and communication.

While **bandwidth can become a performance bottleneck** in some high-throughput scenarios, upgrading to newer generations like **PCIe Gen5 or Gen6** introduces significant trade-offs in **cost, complexity, and platform requirements**. For now, Gen4 remains the **most practical and widely supported option**, but we remain open to adopting higher PCIe generations if communication overhead proves to be a critical limiting factor.

#### **Performance Summary**

| Task                          | Time             |
| ----------------------------- | ---------------- |
| Field addition                | < 3 ms           |
| PCIe communication (`Reduce`) | **92 \~ 106 ms** |
| **Total Execution Time**      | **95 \~ 109 ms** |

According to ICICLE-Snark, a polynomial of degree $$2^{23}$$ takes approximately **774 ms** for MSM. With four devices, each handling a degree $$2^{21}$$ polynomial, the per-device MSM time is around **193 ms**. If we overlap the `Reduce` step with two of these MSMs, the communication overhead can be effectively hidden.

## Estimation

Let’s perform a simple estimation based on the data from [**ICICLE-Snark**](https://medium.com/@ingonyama/icicle-snark-the-fastest-groth16-implementation-in-the-world-00901b39a21f).

### **Assumptions**

1. $$n$$ and $$m$$ are $$2^{22}$$
2. The runtime for MSM over $$\mathbb{G}\_2$$ is the same as that for $$\mathbb{G}\_1$$.
3. The runtime for $$\mathbb{G}\_1$$ MSM scales linearly with the degree.
4. The total proving time for Groth16 is the sum of the time taken for 5 MSMs (387 ms), 1 FFT ( 10 ms), and 1 IFFT (10 ms).
5. The time for `Reduce`, `Recv` and `Send` is negligible.

### Computation

If everything is computed **serially**, the total time is:

$$
5 \times 387 + 10 + 10 = 1955\ \text{ms}
$$

If we instead use the proposed scheme across **4 GPUs**, the time becomes:

$$
5 \times (387 / 4) + 10 + 10 \approx 504\ \text{ms}
$$

This shows that the proving time is reduced by approximately a **factor of 4**.

### Input Size

We do not include $$A, B, C$$ in our input size estimation, as their sparsity makes it difficult to quantify preciesly. However, they will also contribute to reducing the overall memory requirement.

If everything is computed on a single device, the input size is around $$1664$$ MB. Each component will consume memory as follows:

* Witness vector $$\bm{z}$$: $$2^{22} \times 32$$ B $$= 128$$ MB
* MSM base point $$\left(\[a\_i(x)]*1\right)*{i = 0}^{m}$$: $$2^{22} \times 64$$ B $$= 256$$ MB
* MSM base point $$\left(\[b\_i(x)]*1\right)*{i = 0}^{m}$$: $$2^{22} \times 64$$ B $$= 256$$ MB
* MSM base point $$\left(\[b\_i(x)]*2\right)*{i = 0}^{m}$$: $$2^{22} \times 128$$ B $$= 512$$ MB
* MSM base point $$\left(\left\[ \frac{\beta  a\_i(x) + \alpha  b\_i(x) + c\_i(x)}{\delta} \right]*1\right)*{i = \ell + 1}^{m}$$: $$2^{22} \times 64$$ B $$\approx 256$$ MB
* MSM base point $$\left(\left\[ \frac{{L'}\_{2i + 1}(x)}{\delta} \right]*1\right)*{i = 0}^{n-2}$$: $$2^{22} \times 64$$ B $$\approx 256$$ MB&#x20;

If we instead use the proposed scheme across **4 GPUs**, the input size becomes around $$416$$ MB.

### Runtime Memory Size

{% hint style="warning" %}
Here, we estimate only the memory required for MSM itself, **excluding the additional memory needed for buckets in the** [**Pippenger**](https://fractalyze.gitbook.io/intro/~/revisions/0AAov1j5GF4J6Ca62R1w/primitives/abstract-algebra/elliptic-curve/msm/pippengers-algorithm) **algorithm**.
{% endhint %}

If everything is computed on a single device, and intermediate memory is released immediately after use, the main bottleneck becomes the **MSM in** $$\mathbb{G}\_2$$, which requires $$128 + 512 = 640$$ MB of memory.

However, with the proposed scheme using **4 GPUs**, this memory requirement is reduced to $$32 + 128 = 160$$ MB. The **Full FFT**, including twiddle factors, requires an additional $$256$$ MB. Therefore, under this setup, **the total memory required per device is approximately** $$256$$ MB.

> Written by [Ryan Kim](https://app.gitbook.com/u/cPk8gft4tSd0Obi6ARBfoQ16SqG2 "mention") of Fractalyze
