Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

o1vm/mips: use batch_inversion for the witness generation #2813

Merged
merged 15 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 36 additions & 7 deletions o1vm/src/interpreters/mips/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ use kimchi_msm::{
use std::ops::{Index, IndexMut};
use strum::EnumCount;

use super::{ITypeInstruction, JTypeInstruction, RTypeInstruction};
pub use super::{
witness::SCRATCH_SIZE_INVERSE, ITypeInstruction, JTypeInstruction, RTypeInstruction,
};

/// The number of hashes performed so far in the block
pub(crate) const MIPS_HASH_COUNTER_OFF: usize = 80;
Expand All @@ -35,7 +37,7 @@ pub(crate) const MIPS_CHUNK_BYTES_LEN: usize = 4;
pub(crate) const MIPS_PREIMAGE_KEY: usize = 97;

/// The number of columns used for relation witness in the MIPS circuit
pub const N_MIPS_REL_COLS: usize = SCRATCH_SIZE + 2;
pub const N_MIPS_REL_COLS: usize = SCRATCH_SIZE + SCRATCH_SIZE_INVERSE + 2;

/// The number of witness columns used to store the instruction selectors.
pub const N_MIPS_SEL_COLS: usize =
Expand All @@ -50,6 +52,9 @@ pub const N_MIPS_COLS: usize = N_MIPS_REL_COLS + N_MIPS_SEL_COLS;
pub enum ColumnAlias {
// Can be seen as the abstract indexed variable X_{i}
ScratchState(usize),
// A column whose value needs to be inverted in the final witness.
// We're keeping a separate column to perform a batch inversion at the end.
ScratchStateInverse(usize),
InstructionCounter,
Selector(usize),
}
Expand All @@ -66,8 +71,12 @@ impl From<ColumnAlias> for usize {
assert!(i < SCRATCH_SIZE);
i
}
ColumnAlias::InstructionCounter => SCRATCH_SIZE,
ColumnAlias::Selector(s) => SCRATCH_SIZE + 1 + s,
ColumnAlias::ScratchStateInverse(i) => {
assert!(i < SCRATCH_SIZE_INVERSE);
SCRATCH_SIZE + i
}
ColumnAlias::InstructionCounter => SCRATCH_SIZE + SCRATCH_SIZE_INVERSE,
ColumnAlias::Selector(s) => SCRATCH_SIZE + SCRATCH_SIZE_INVERSE + 1 + s,
}
}
}
Expand Down Expand Up @@ -132,16 +141,36 @@ impl<T: Clone> IndexMut<ColumnAlias> for MIPSWitness<T> {

impl ColumnIndexer for ColumnAlias {
const N_COL: usize = N_MIPS_COLS;

fn to_column(self) -> Column {
match self {
Self::ScratchState(ss) => {
assert!(ss < SCRATCH_SIZE);
assert!(
ss < SCRATCH_SIZE,
"The maximum index is {}, got {}",
SCRATCH_SIZE,
ss
);
Column::Relation(ss)
}
Self::InstructionCounter => Column::Relation(SCRATCH_SIZE),
Self::ScratchStateInverse(ss) => {
assert!(
ss < SCRATCH_SIZE_INVERSE,
"The maximum index is {}, got {}",
SCRATCH_SIZE_INVERSE,
ss
);
Column::Relation(SCRATCH_SIZE + ss)
}
Self::InstructionCounter => Column::Relation(SCRATCH_SIZE + SCRATCH_SIZE_INVERSE),
// TODO: what happens with error? It does not have a corresponding alias
Self::Selector(s) => {
assert!(s < N_MIPS_SEL_COLS);
assert!(
s < N_MIPS_SEL_COLS,
"The maximum index is {}, got {}",
N_MIPS_SEL_COLS,
s
);
Column::DynamicSelector(s)
}
}
Expand Down
11 changes: 10 additions & 1 deletion o1vm/src/interpreters/mips/constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use super::column::N_MIPS_SEL_COLS;
/// The environment keeping the constraints between the different polynomials
pub struct Env<Fp> {
scratch_state_idx: usize,
scratch_state_idx_inverse: usize,
/// A list of constraints, which are multi-variate polynomials over a field,
/// represented using the expression framework of `kimchi`.
constraints: Vec<E<Fp>>,
Expand All @@ -37,6 +38,7 @@ impl<Fp: Field> Default for Env<Fp> {
fn default() -> Self {
Self {
scratch_state_idx: 0,
scratch_state_idx_inverse: 0,
constraints: Vec::new(),
lookups: Vec::new(),
selector: None,
Expand All @@ -62,6 +64,12 @@ impl<Fp: Field> InterpreterEnv for Env<Fp> {
MIPSColumn::ScratchState(scratch_idx)
}

fn alloc_scratch_inverse(&mut self) -> Self::Position {
let scratch_idx = self.scratch_state_idx_inverse;
self.scratch_state_idx_inverse += 1;
MIPSColumn::ScratchStateInverse(scratch_idx)
}

type Variable = E<Fp>;

fn variable(&self, column: Self::Position) -> Self::Variable {
Expand Down Expand Up @@ -219,7 +227,7 @@ impl<Fp: Field> InterpreterEnv for Env<Fp> {
unsafe { self.test_zero(x, pos) }
};
let x_inv_or_zero = {
let pos = self.alloc_scratch();
let pos = self.alloc_scratch_inverse();
self.variable(pos)
};
// If x = 0, then res = 1 and x_inv_or_zero = 0
Expand Down Expand Up @@ -623,6 +631,7 @@ impl<Fp: Field> InterpreterEnv for Env<Fp> {

fn reset(&mut self) {
self.scratch_state_idx = 0;
self.scratch_state_idx_inverse = 0;
self.constraints.clear();
self.lookups.clear();
self.selector = None;
Expand Down
2 changes: 2 additions & 0 deletions o1vm/src/interpreters/mips/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ pub trait InterpreterEnv {
/// [crate::interpreters::mips::witness::SCRATCH_SIZE]
fn alloc_scratch(&mut self) -> Self::Position;

fn alloc_scratch_inverse(&mut self) -> Self::Position;

type Variable: Clone
+ std::ops::Add<Self::Variable, Output = Self::Variable>
+ std::ops::Sub<Self::Variable, Output = Self::Variable>
Expand Down
1 change: 1 addition & 0 deletions o1vm/src/interpreters/mips/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ mod rtype {
// that condition would generate an infinite loop instead)
while dummy_env.registers.preimage_offset < total_length {
dummy_env.reset_scratch_state();
dummy_env.reset_scratch_state_inverse();

// Set maximum number of bytes to read in this call
dummy_env.registers[6] = rng.gen_range(1..=4);
Expand Down
4 changes: 4 additions & 0 deletions o1vm/src/interpreters/mips/tests_helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ use std::{fs, path::PathBuf};
// FIXME: we should parametrize the tests with different fields.
use ark_bn254::Fr as Fp;

use super::witness::SCRATCH_SIZE_INVERSE;

const PAGE_INDEX_EXECUTABLE_MEMORY: u32 = 1;

pub(crate) struct OnDiskPreImageOracle;
Expand Down Expand Up @@ -87,7 +89,9 @@ where
registers: Registers::default(),
registers_write_index: Registers::default(),
scratch_state_idx: 0,
scratch_state_idx_inverse: 0,
scratch_state: [Fp::from(0); SCRATCH_SIZE],
scratch_state_inverse: [Fp::from(0); SCRATCH_SIZE_INVERSE],
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In a future PR, as suggested by @mrmr1993, we can try to use BigInt/BigUInt directly to avoid computation in Montgomery representation. I do not know if it would make it better.

selector: crate::interpreters::mips::column::N_MIPS_SEL_COLS,
halt: false,
// Keccak related
Expand Down
54 changes: 37 additions & 17 deletions o1vm/src/interpreters/mips/witness.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,17 @@ pub const NUM_INSTRUCTION_LOOKUP_TERMS: usize = 5;
pub const NUM_LOOKUP_TERMS: usize =
NUM_GLOBAL_LOOKUP_TERMS + NUM_DECODING_LOOKUP_TERMS + NUM_INSTRUCTION_LOOKUP_TERMS;
// TODO: Delete and use a vector instead
// FIXME: since the introduction of the scratch size inverse, the value below
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fixed in #2815.

// can be decreased. It implies to change the offsets defined in [column]. At
// the moment, it incurs an overhead we could avoid as some columns are zeroes.
// MIPS + hash_counter + byte_counter + eof + num_bytes_read + chunk + bytes
// + length + has_n_bytes + chunk_bytes + preimage
pub const SCRATCH_SIZE: usize = 98;

/// Number of columns used by the MIPS interpreter to keep values to be
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To verify it was the minimal value, I did try to use 11, and it failed.
From there, I run for some millions instructions the op-program, and it never failed.

/// inverted.
pub const SCRATCH_SIZE_INVERSE: usize = 12;

#[derive(Clone, Default)]
pub struct SyscallEnv {
pub last_hint: Option<Vec<u8>>,
Expand Down Expand Up @@ -81,7 +88,9 @@ pub struct Env<Fp, PreImageOracle: PreImageOracleT> {
pub registers: Registers<u32>,
pub registers_write_index: Registers<u64>,
pub scratch_state_idx: usize,
pub scratch_state_idx_inverse: usize,
pub scratch_state: [Fp; SCRATCH_SIZE],
pub scratch_state_inverse: [Fp; SCRATCH_SIZE_INVERSE],
pub halt: bool,
pub syscall_env: SyscallEnv,
pub selector: usize,
Expand All @@ -106,6 +115,12 @@ impl<Fp: Field, PreImageOracle: PreImageOracleT> InterpreterEnv for Env<Fp, PreI
Column::ScratchState(scratch_idx)
}

fn alloc_scratch_inverse(&mut self) -> Self::Position {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like this function for the mips witness is the same as the one in the mips constraints. I wonder if it could be generically defined in the interpreter itself, but probably not directly given that it uses the fact that it accesses a self.scratch_state_idx_inverse. Maybe environments can be defined to implement a trait that returns it, but not sure if it is worth the code duplication.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do think it is fine at the moment. Agree it is a bit "annoying"/"ugly" to have duplicated code. But lgtm for the moment.

let scratch_idx = self.scratch_state_idx_inverse;
self.scratch_state_idx_inverse += 1;
Column::ScratchStateInverse(scratch_idx)
}

type Variable = u64;

fn variable(&self, _column: Self::Position) -> Self::Variable {
Expand Down Expand Up @@ -314,17 +329,17 @@ impl<Fp: Field, PreImageOracle: PreImageOracleT> InterpreterEnv for Env<Fp, PreI

fn is_zero(&mut self, x: &Self::Variable) -> Self::Variable {
// write the result
let pos = self.alloc_scratch();
let res = if *x == 0 { 1 } else { 0 };
self.write_column(pos, res);
let res = {
let pos = self.alloc_scratch();
unsafe { self.test_zero(x, pos) }
};
// write the non deterministic advice inv_or_zero
let pos = self.alloc_scratch();
let inv_or_zero = if *x == 0 {
Fp::zero()
let pos = self.alloc_scratch_inverse();
if *x == 0 {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure why there's this if-else distinction if it always writes in the inverse scratch state. Is it to avoid a conversion in case Fp::zero() is faster than Fp:from(0)? Either way I suppose the batch inversion algorithm will just ignore zeros, I believe I checked that in the arkworks code.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No reason. Seems to be an ugly code 😅. For a follow-up.

self.write_field_column(pos, Fp::zero());
} else {
Fp::inverse(&Fp::from(*x)).unwrap()
self.write_field_column(pos, Fp::from(*x));
};
self.write_field_column(pos, inv_or_zero);
// return the result
res
}
Expand All @@ -339,15 +354,11 @@ impl<Fp: Field, PreImageOracle: PreImageOracleT> InterpreterEnv for Env<Fp, PreI
self.write_column(pos, is_zero);
is_zero
};
let _to_zero_test_inv_or_zero = {
let pos = self.alloc_scratch();
let inv_or_zero = if to_zero_test == Fp::zero() {
Fp::zero()
} else {
Fp::inverse(&to_zero_test).unwrap()
};
self.write_field_column(pos, inv_or_zero);
1 // Placeholder value
let pos = self.alloc_scratch_inverse();
dannywillems marked this conversation as resolved.
Show resolved Hide resolved
if to_zero_test == Fp::zero() {
self.write_field_column(pos, Fp::zero());
} else {
self.write_field_column(pos, to_zero_test);
};
res
}
Expand Down Expand Up @@ -878,7 +889,9 @@ impl<Fp: Field, PreImageOracle: PreImageOracleT> Env<Fp, PreImageOracle> {
registers: initial_registers.clone(),
registers_write_index: Registers::default(),
scratch_state_idx: 0,
scratch_state_idx_inverse: 0,
scratch_state: fresh_scratch_state(),
scratch_state_inverse: fresh_scratch_state(),
halt: state.exited,
syscall_env,
selector,
Expand All @@ -897,13 +910,19 @@ impl<Fp: Field, PreImageOracle: PreImageOracleT> Env<Fp, PreImageOracle> {
self.selector = N_MIPS_SEL_COLS;
}

pub fn reset_scratch_state_inverse(&mut self) {
self.scratch_state_idx_inverse = 0;
self.scratch_state_inverse = fresh_scratch_state();
}

pub fn write_column(&mut self, column: Column, value: u64) {
self.write_field_column(column, value.into())
}

pub fn write_field_column(&mut self, column: Column, value: Fp) {
match column {
Column::ScratchState(idx) => self.scratch_state[idx] = value,
Column::ScratchStateInverse(idx) => self.scratch_state_inverse[idx] = value,
Column::InstructionCounter => panic!("Cannot overwrite the column {:?}", column),
Column::Selector(s) => self.selector = s,
}
Expand Down Expand Up @@ -1138,6 +1157,7 @@ impl<Fp: Field, PreImageOracle: PreImageOracleT> Env<Fp, PreImageOracle> {
start: &Start,
) -> Instruction {
self.reset_scratch_state();
self.reset_scratch_state_inverse();
let (opcode, _instruction) = self.decode_instruction();

self.pp_info(&config.info_at, metadata, start);
Expand Down
28 changes: 20 additions & 8 deletions o1vm/src/pickles/column_env.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ use ark_poly::{Evaluations, Radix2EvaluationDomain};
use kimchi_msm::columns::Column;

use crate::{
interpreters::mips::{column::N_MIPS_SEL_COLS, witness::SCRATCH_SIZE},
interpreters::mips::{
column::N_MIPS_SEL_COLS,
witness::{SCRATCH_SIZE, SCRATCH_SIZE_INVERSE},
},
pickles::proof::WitnessColumns,
};
use kimchi::circuits::{
Expand Down Expand Up @@ -36,8 +39,9 @@ pub struct ColumnEnvironment<'a, F: FftField> {
}

pub fn get_all_columns() -> Vec<Column> {
let mut cols = Vec::<Column>::with_capacity(SCRATCH_SIZE + 2 + N_MIPS_SEL_COLS);
for i in 0..SCRATCH_SIZE + 2 {
let mut cols =
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps use N_MIPS_REL_COLS, or you want to be super explicit here?

Copy link
Member Author

@dannywillems dannywillems Nov 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, you are right about using N_MIPS_REL_COLS.

Vec::<Column>::with_capacity(SCRATCH_SIZE + SCRATCH_SIZE_INVERSE + 2 + N_MIPS_SEL_COLS);
for i in 0..SCRATCH_SIZE + SCRATCH_SIZE_INVERSE + 2 {
cols.push(Column::Relation(i));
}
for i in 0..N_MIPS_SEL_COLS {
Expand All @@ -53,26 +57,34 @@ impl<G> WitnessColumns<G, [G; N_MIPS_SEL_COLS]> {
if i < SCRATCH_SIZE {
let res = &self.scratch[i];
Some(res)
} else if i == SCRATCH_SIZE {
} else if i < SCRATCH_SIZE + SCRATCH_SIZE_INVERSE {
let res = &self.scratch_inverse[i - SCRATCH_SIZE];
Some(res)
} else if i == SCRATCH_SIZE + SCRATCH_SIZE_INVERSE {
let res = &self.instruction_counter;
Some(res)
} else if i == SCRATCH_SIZE + 1 {
} else if i == SCRATCH_SIZE + SCRATCH_SIZE_INVERSE + 1 {
let res = &self.error;
Some(res)
} else {
panic!("We should not have that many relation columns");
panic!("We should not have that many relation columns. We have {} columns and index {} was given", SCRATCH_SIZE + SCRATCH_SIZE_INVERSE + 2, i);
}
}
Column::DynamicSelector(i) => {
assert!(
i < N_MIPS_SEL_COLS,
"We do not have that many dynamic selector columns"
"We do not have that many dynamic selector columns. We have {} columns and index {} was given",
N_MIPS_SEL_COLS,
i
);
let res = &self.selector[i];
Some(res)
}
_ => {
panic!("We should not have any other type of columns")
panic!(
"We should not have any other type of columns. The column {:?} was given",
col
);
}
}
}
Expand Down
7 changes: 7 additions & 0 deletions o1vm/src/pickles/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,13 @@ pub fn main() -> ExitCode {
{
scratch_chunk.push(*scratch);
}
for (scratch, scratch_chunk) in mips_wit_env
.scratch_state_inverse
.iter()
.zip(curr_proof_inputs.evaluations.scratch_inverse.iter_mut())
{
scratch_chunk.push(*scratch);
}
curr_proof_inputs
.evaluations
.instruction_counter
Expand Down
7 changes: 6 additions & 1 deletion o1vm/src/pickles/proof.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
use kimchi::{curve::KimchiCurve, proof::PointEvaluations};
use poly_commitment::{ipa::OpeningProof, PolyComm};

use crate::interpreters::mips::{column::N_MIPS_SEL_COLS, witness::SCRATCH_SIZE};
use crate::interpreters::mips::{
column::N_MIPS_SEL_COLS,
witness::{SCRATCH_SIZE, SCRATCH_SIZE_INVERSE},
};

pub struct WitnessColumns<G, S> {
pub scratch: [G; SCRATCH_SIZE],
pub scratch_inverse: [G; SCRATCH_SIZE_INVERSE],
pub instruction_counter: G,
pub error: G,
pub selector: S,
Expand All @@ -19,6 +23,7 @@ impl<G: KimchiCurve> ProofInputs<G> {
ProofInputs {
evaluations: WitnessColumns {
scratch: std::array::from_fn(|_| Vec::with_capacity(domain_size)),
scratch_inverse: std::array::from_fn(|_| Vec::with_capacity(domain_size)),
instruction_counter: Vec::with_capacity(domain_size),
error: Vec::with_capacity(domain_size),
selector: Vec::with_capacity(domain_size),
Expand Down
Loading
Loading