diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml index ca9cef5..f99cfe5 100644 --- a/.github/workflows/benchmarks.yml +++ b/.github/workflows/benchmarks.yml @@ -23,7 +23,7 @@ env: jobs: start-runner: - runs-on: ubuntu-latest + runs-on: ubuntu-latest steps: - name: Configure AWS credentials uses: aws-actions/configure-aws-credentials@v1 @@ -160,12 +160,12 @@ jobs: BENCHY_OUTPUT_DIR: "../../.benchmarks/${{ matrix.runner }}" run: | cargo bench -F metal - + - name: Upload .benchmarks as artifact uses: actions/upload-artifact@v3 with: path: .benchmarks/ - + noir: runs-on: ${{ matrix.runner }} strategy: @@ -219,7 +219,6 @@ jobs: override: true - uses: Swatinem/rust-cache@v2 - - name: Run Leo benchmarks working-directory: ./leo env: @@ -231,7 +230,6 @@ jobs: uses: actions/upload-artifact@v3 with: path: .benchmarks/ - commit: runs-on: ubuntu-latest needs: [polylang, miden, risc_zero, noir] @@ -246,7 +244,6 @@ jobs: - name: Combine results run: ./combine.py - - name: Copy benchmark results to site run: | cp benchmarks.json site/src/fixtures/benchmarks.json @@ -270,10 +267,5 @@ jobs: aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} aws-region: us-east-1 - - name: Stop EC2 instance run: aws ec2 stop-instances --instance-ids ${{ env.INSTANCE_ID }} - - - - diff --git a/Pipfile b/Pipfile new file mode 100644 index 0000000..74ab9f5 --- /dev/null +++ b/Pipfile @@ -0,0 +1,12 @@ +[[source]] +url = "https://pypi.org/simple" +verify_ssl = true +name = "pypi" + +[packages] +torch = "==2.1.0" + +[dev-packages] + +[requires] +python_version = "3.11" diff --git a/Pipfile.lock b/Pipfile.lock new file mode 100644 index 0000000..bfd889a --- /dev/null +++ b/Pipfile.lock @@ -0,0 +1,169 @@ +{ + "_meta": { + "hash": { + "sha256": "51dac6e8ac31c166718876a94f3237052b4bce1110c3c65f32d892be21148a50" + }, + "pipfile-spec": 6, + "requires": { + "python_version": "3.11" + }, + "sources": [ + { + "name": 
"pypi", + "url": "https://pypi.org/simple", + "verify_ssl": true + } + ] + }, + "default": { + "filelock": { + "hashes": [ + "sha256:521f5f56c50f8426f5e03ad3b281b490a87ef15bc6c526f168290f0c7148d44e", + "sha256:57dbda9b35157b05fb3e58ee91448612eb674172fab98ee235ccb0b5bee19a1c" + ], + "markers": "python_version >= '3.8'", + "version": "==3.13.1" + }, + "fsspec": { + "hashes": [ + "sha256:6271f1d3075a378bfe432f6f42bf7e1d2a6ba74f78dd9b512385474c579146a0", + "sha256:c4da01a35ac65c853f833e43f67802c25213f560820d54ddf248f92eddd5e990" + ], + "markers": "python_version >= '3.8'", + "version": "==2023.12.1" + }, + "jinja2": { + "hashes": [ + "sha256:31351a702a408a9e7595a8fc6150fc3f43bb6bf7e319770cbc0db9df9437e852", + "sha256:6088930bfe239f0e6710546ab9c19c9ef35e29792895fed6e6e31a023a182a61" + ], + "markers": "python_version >= '3.7'", + "version": "==3.1.2" + }, + "markupsafe": { + "hashes": [ + "sha256:05fb21170423db021895e1ea1e1f3ab3adb85d1c2333cbc2310f2a26bc77272e", + "sha256:0a4e4a1aff6c7ac4cd55792abf96c915634c2b97e3cc1c7129578aa68ebd754e", + "sha256:10bbfe99883db80bdbaff2dcf681dfc6533a614f700da1287707e8a5d78a8431", + "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686", + "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c", + "sha256:1577735524cdad32f9f694208aa75e422adba74f1baee7551620e43a3141f559", + "sha256:1b40069d487e7edb2676d3fbdb2b0829ffa2cd63a2ec26c4938b2d34391b4ecc", + "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb", + "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939", + "sha256:282c2cb35b5b673bbcadb33a585408104df04f14b2d9b01d4c345a3b92861c2c", + "sha256:2c1b19b3aaacc6e57b7e25710ff571c24d6c3613a45e905b1fde04d691b98ee0", + "sha256:2ef12179d3a291be237280175b542c07a36e7f60718296278d8593d21ca937d4", + "sha256:338ae27d6b8745585f87218a3f23f1512dbf52c26c28e322dbe54bcede54ccb9", + "sha256:3c0fae6c3be832a0a0473ac912810b2877c8cb9d76ca48de1ed31e1c68386575", + 
"sha256:3fd4abcb888d15a94f32b75d8fd18ee162ca0c064f35b11134be77050296d6ba", + "sha256:42de32b22b6b804f42c5d98be4f7e5e977ecdd9ee9b660fda1a3edf03b11792d", + "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd", + "sha256:504b320cd4b7eff6f968eddf81127112db685e81f7e36e75f9f84f0df46041c3", + "sha256:525808b8019e36eb524b8c68acdd63a37e75714eac50e988180b169d64480a00", + "sha256:56d9f2ecac662ca1611d183feb03a3fa4406469dafe241673d521dd5ae92a155", + "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac", + "sha256:65c1a9bcdadc6c28eecee2c119465aebff8f7a584dd719facdd9e825ec61ab52", + "sha256:68e78619a61ecf91e76aa3e6e8e33fc4894a2bebe93410754bd28fce0a8a4f9f", + "sha256:69c0f17e9f5a7afdf2cc9fb2d1ce6aabdb3bafb7f38017c0b77862bcec2bbad8", + "sha256:6b2b56950d93e41f33b4223ead100ea0fe11f8e6ee5f641eb753ce4b77a7042b", + "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007", + "sha256:787003c0ddb00500e49a10f2844fac87aa6ce977b90b0feaaf9de23c22508b24", + "sha256:7ef3cb2ebbf91e330e3bb937efada0edd9003683db6b57bb108c4001f37a02ea", + "sha256:8023faf4e01efadfa183e863fefde0046de576c6f14659e8782065bcece22198", + "sha256:8758846a7e80910096950b67071243da3e5a20ed2546e6392603c096778d48e0", + "sha256:8afafd99945ead6e075b973fefa56379c5b5c53fd8937dad92c662da5d8fd5ee", + "sha256:8c41976a29d078bb235fea9b2ecd3da465df42a562910f9022f1a03107bd02be", + "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2", + "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1", + "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707", + "sha256:962f82a3086483f5e5f64dbad880d31038b698494799b097bc59c2edf392fce6", + "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c", + "sha256:9dcdfd0eaf283af041973bff14a2e143b8bd64e069f4c383416ecd79a81aab58", + "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823", + "sha256:aa7bd130efab1c280bed0f45501b7c8795f9fdbeb02e965371bbef3523627779", 
+ "sha256:ab4a0df41e7c16a1392727727e7998a467472d0ad65f3ad5e6e765015df08636", + "sha256:ad9e82fb8f09ade1c3e1b996a6337afac2b8b9e365f926f5a61aacc71adc5b3c", + "sha256:af598ed32d6ae86f1b747b82783958b1a4ab8f617b06fe68795c7f026abbdcad", + "sha256:b076b6226fb84157e3f7c971a47ff3a679d837cf338547532ab866c57930dbee", + "sha256:b7ff0f54cb4ff66dd38bebd335a38e2c22c41a8ee45aa608efc890ac3e3931bc", + "sha256:bfce63a9e7834b12b87c64d6b155fdd9b3b96191b6bd334bf37db7ff1fe457f2", + "sha256:c011a4149cfbcf9f03994ec2edffcb8b1dc2d2aede7ca243746df97a5d41ce48", + "sha256:c9c804664ebe8f83a211cace637506669e7890fec1b4195b505c214e50dd4eb7", + "sha256:ca379055a47383d02a5400cb0d110cef0a776fc644cda797db0c5696cfd7e18e", + "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b", + "sha256:cd0f502fe016460680cd20aaa5a76d241d6f35a1c3350c474bac1273803893fa", + "sha256:ceb01949af7121f9fc39f7d27f91be8546f3fb112c608bc4029aef0bab86a2a5", + "sha256:d080e0a5eb2529460b30190fcfcc4199bd7f827663f858a226a81bc27beaa97e", + "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb", + "sha256:df0be2b576a7abbf737b1575f048c23fb1d769f267ec4358296f31c2479db8f9", + "sha256:e09031c87a1e51556fdcb46e5bd4f59dfb743061cf93c4d6831bf894f125eb57", + "sha256:e4dd52d80b8c83fdce44e12478ad2e85c64ea965e75d66dbeafb0a3e77308fcc", + "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc", + "sha256:fec21693218efe39aa7f8599346e90c705afa52c5b31ae019b2e57e8f6542bb2", + "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11" + ], + "markers": "python_version >= '3.7'", + "version": "==2.1.3" + }, + "mpmath": { + "hashes": [ + "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f", + "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c" + ], + "version": "==1.3.0" + }, + "networkx": { + "hashes": [ + "sha256:9f1bb5cf3409bf324e0a722c20bdb4c20ee39bf1c30ce8ae499c8502b0b5e0c6", + 
"sha256:f18c69adc97877c42332c170849c96cefa91881c99a7cb3e95b7c659ebdc1ec2" + ], + "markers": "python_version >= '3.9'", + "version": "==3.2.1" + }, + "sympy": { + "hashes": [ + "sha256:c3588cd4295d0c0f603d0f2ae780587e64e2efeedb3521e46b9bb1d08d184fa5", + "sha256:ebf595c8dac3e0fdc4152c51878b498396ec7f30e7a914d6071e674d49420fb8" + ], + "markers": "python_version >= '3.8'", + "version": "==1.12" + }, + "torch": { + "hashes": [ + "sha256:05661c32ec14bc3a157193d0f19a7b19d8e61eb787b33353cad30202c295e83b", + "sha256:0bd691efea319b14ef239ede16d8a45c246916456fa3ed4f217d8af679433cc6", + "sha256:101c139152959cb20ab370fc192672c50093747906ee4ceace44d8dd703f29af", + "sha256:2224622407ca52611cbc5b628106fde22ed8e679031f5a99ce286629fc696128", + "sha256:2419cf49aaf3b2336c7aa7a54a1b949fa295b1ae36f77e2aecb3a74e3a947255", + "sha256:3cd1dedff13884d890f18eea620184fb4cd8fd3c68ce3300498f427ae93aa962", + "sha256:421739685eba5e0beba42cb649740b15d44b0d565c04e6ed667b41148734a75b", + "sha256:458a6d6d8f7d2ccc348ac4d62ea661b39a3592ad15be385bebd0a31ced7e00f4", + "sha256:556d8dd3e0c290ed9d4d7de598a213fb9f7c59135b4fee144364a8a887016a55", + "sha256:5c3bfa91ce25ba10116c224c59d5b64cdcce07161321d978bd5a1f15e1ebce72", + "sha256:601b0a2a9d9233fb4b81f7d47dca9680d4f3a78ca3f781078b6ad1ced8a90523", + "sha256:6ad491e70dbe4288d17fdbfc7fbfa766d66cbe219bc4871c7a8096f4a37c98df", + "sha256:761822761fffaa1c18a62c5deb13abaa780862577d3eadc428f1daa632536905", + "sha256:8132efb782cd181cc2dcca5e58effbe4217cdb2581206ac71466d535bf778867", + "sha256:a04a0296d47f28960f51c18c5489a8c3472f624ec3b5bcc8e2096314df8c3342", + "sha256:a6b7438a90a870e4cdeb15301519ae6c043c883fcd224d303c5b118082814767", + "sha256:bf57f8184b2c317ef81fb33dc233ce4d850cd98ef3f4a38be59c7c1572d175db", + "sha256:c8bf7eaf9514465e5d9101e05195183470a6215bb50295c61b52302a04edb690", + "sha256:de7d63c6ecece118684415a3dbd4805af4a4c1ee1490cccf7405d8c240a481b4", + "sha256:fb7bf0cc1a3db484eb5d713942a93172f3bac026fcb377a0cd107093d2eba777" + ], + "index": "pypi", + 
"markers": "python_full_version >= '3.8.0'", + "version": "==2.1.0" + }, + "typing-extensions": { + "hashes": [ + "sha256:8f92fc8806f9a6b641eaa5318da32b44d401efaac0f6678c9bc448ba3605faa0", + "sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef" + ], + "markers": "python_version >= '3.8'", + "version": "==4.8.0" + } + }, + "develop": {} +} diff --git a/mnist_ezkl/Cargo.lock b/mnist_ezkl/Cargo.lock new file mode 100644 index 0000000..8ca9b3b --- /dev/null +++ b/mnist_ezkl/Cargo.lock @@ -0,0 +1,6174 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "Inflector" +version = "0.11.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe438c63458706e03479442743baae6c88256498e6431708f6dfc520a26515d3" +dependencies = [ + "lazy_static", + "regex", +] + +[[package]] +name = "adler" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" + +[[package]] +name = "aes" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac1f845298e95f983ff1944b728ae08b8cebab80d684f0a832ed0fc74dfa27e2" +dependencies = [ + "cfg-if", + "cipher", + "cpufeatures", +] + +[[package]] +name = "ahash" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47" +dependencies = [ + "getrandom", + "once_cell", + "version_check", +] + +[[package]] +name = "ahash" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f" +dependencies = [ + "cfg-if", + "once_cell", + "version_check", +] + +[[package]] +name = "aho-corasick" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"43f6cb1bf222025340178f382c426f13757b2960e89779dfcb319c32542a5a41" +dependencies = [ + "memchr", +] + +[[package]] +name = "alloy-rlp" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc0fac0fc16baf1f63f78b47c3d24718f3619b0714076f6a02957d808d52cbef" +dependencies = [ + "arrayvec 0.7.4", + "bytes", + "smol_str", +] + +[[package]] +name = "android-tzdata" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" + +[[package]] +name = "ansi-str" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1cf4578926a981ab0ca955dc023541d19de37112bc24c1a197bd806d3d86ad1d" +dependencies = [ + "ansitok", +] + +[[package]] +name = "ansitok" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "220044e6a1bb31ddee4e3db724d29767f352de47445a6cd75e1a173142136c83" +dependencies = [ + "nom", + "vte", +] + +[[package]] +name = "anstream" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ca84f3628370c59db74ee214b3263d58f9aadd9b4fe7e711fd87dc452b7f163" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is-terminal", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a30da5c5f2d5e72842e00bcb57657162cdabef0931f40e2deb9b4140440cecd" + +[[package]] +name = "anstyle-parse" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "938874ff5980b03a87c5524b3ae5b59cf99b1d6bc836848df7bc5ada9643c333" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"5ca11d4be1bab0c8bc8734a9aa7bf4ee8316d462a08c6ac5052f888fef5b494b" +dependencies = [ + "windows-sys 0.48.0", +] + +[[package]] +name = "anstyle-wincon" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "180abfa45703aebe0093f79badacc01b8fd4ea2e35118747e5811127f926e188" +dependencies = [ + "anstyle", + "windows-sys 0.48.0", +] + +[[package]] +name = "anyhow" +version = "1.0.71" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c7d0618f0e0b7e8ff11427422b64564d5fb0be1940354bfe2e0529b18a9d9b8" + +[[package]] +name = "anymap2" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d301b3b94cb4b2f23d7917810addbbaff90738e0ca2be692bd027e70d7e0330c" + +[[package]] +name = "arc-swap" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bddcadddf5e9015d310179a59bb28c4d4b9920ad0f11e8e14dbadf654890c9a6" + +[[package]] +name = "ark-bls12-377" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc41c02c0d18a226947ee9ee023b1d957bdb6a68fc22ac296722935a9fef423c" +dependencies = [ + "ark-ec", + "ark-ff 0.3.0", + "ark-std 0.3.0", +] + +[[package]] +name = "ark-bls12-381" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65be532f9dd1e98ad0150b037276cde464c6f371059e6dd02c0222395761f6aa" +dependencies = [ + "ark-ec", + "ark-ff 0.3.0", + "ark-std 0.3.0", +] + +[[package]] +name = "ark-bn254" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea691771ebbb28aea556c044e2e5c5227398d840cee0c34d4d20fa8eb2689e8c" +dependencies = [ + "ark-ec", + "ark-ff 0.3.0", + "ark-std 0.3.0", +] + +[[package]] +name = "ark-ec" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dea978406c4b1ca13c2db2373b05cc55429c3575b8b21f1b9ee859aa5b03dd42" +dependencies = [ + 
"ark-ff 0.3.0", + "ark-serialize 0.3.0", + "ark-std 0.3.0", + "derivative", + "num-traits", + "rayon", + "zeroize", +] + +[[package]] +name = "ark-ff" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b3235cc41ee7a12aaaf2c575a2ad7b46713a8a50bda2fc3b003a04845c05dd6" +dependencies = [ + "ark-ff-asm 0.3.0", + "ark-ff-macros 0.3.0", + "ark-serialize 0.3.0", + "ark-std 0.3.0", + "derivative", + "num-bigint", + "num-traits", + "paste", + "rustc_version 0.3.3", + "zeroize", +] + +[[package]] +name = "ark-ff" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec847af850f44ad29048935519032c33da8aa03340876d351dfab5660d2966ba" +dependencies = [ + "ark-ff-asm 0.4.2", + "ark-ff-macros 0.4.2", + "ark-serialize 0.4.2", + "ark-std 0.4.0", + "derivative", + "digest 0.10.7", + "itertools 0.10.5", + "num-bigint", + "num-traits", + "paste", + "rustc_version 0.4.0", + "zeroize", +] + +[[package]] +name = "ark-ff-asm" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db02d390bf6643fb404d3d22d31aee1c4bc4459600aef9113833d17e786c6e44" +dependencies = [ + "quote", + "syn 1.0.109", +] + +[[package]] +name = "ark-ff-asm" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ed4aa4fe255d0bc6d79373f7e31d2ea147bcf486cba1be5ba7ea85abdb92348" +dependencies = [ + "quote", + "syn 1.0.109", +] + +[[package]] +name = "ark-ff-macros" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db2fd794a08ccb318058009eefdf15bcaaaaf6f8161eb3345f907222bac38b20" +dependencies = [ + "num-bigint", + "num-traits", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "ark-ff-macros" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7abe79b0e4288889c4574159ab790824d0033b9fdcb2a112a3182fac2e514565" +dependencies = [ + "num-bigint", + 
"num-traits", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "ark-poly" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b0f78f47537c2f15706db7e98fe64cc1711dbf9def81218194e17239e53e5aa" +dependencies = [ + "ark-ff 0.3.0", + "ark-serialize 0.3.0", + "ark-std 0.3.0", + "derivative", + "hashbrown 0.11.2", +] + +[[package]] +name = "ark-serialize" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d6c2b318ee6e10f8c2853e73a83adc0ccb88995aa978d8a3408d492ab2ee671" +dependencies = [ + "ark-serialize-derive", + "ark-std 0.3.0", + "digest 0.9.0", +] + +[[package]] +name = "ark-serialize" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adb7b85a02b83d2f22f89bd5cac66c9c89474240cb6207cb1efc16d098e822a5" +dependencies = [ + "ark-std 0.4.0", + "digest 0.10.7", + "num-bigint", +] + +[[package]] +name = "ark-serialize-derive" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8dd4e5f0bf8285d5ed538d27fab7411f3e297908fd93c62195de8bee3f199e82" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "ark-std" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1df2c09229cbc5a028b1d70e00fdb2acee28b1055dfb5ca73eea49c5a25c4e7c" +dependencies = [ + "num-traits", + "rand 0.8.5", + "rayon", +] + +[[package]] +name = "ark-std" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94893f1e0c6eeab764ade8dc4c0db24caf4fe7cbbaafc0eba0a9030f447b5185" +dependencies = [ + "num-traits", + "rand 0.8.5", +] + +[[package]] +name = "arrayref" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b4930d2cb77ce62f89ee5d5289b4ac049559b1c45539271f5ed4fdc7db34545" + +[[package]] +name = "arrayvec" +version = "0.5.2" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "23b62fc65de8e4e7f52534fb52b0f3ed04746ae267519eef2a83941e8085068b" + +[[package]] +name = "arrayvec" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" + +[[package]] +name = "ascii-canvas" +version = "3.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8824ecca2e851cec16968d54a01dd372ef8f95b244fb84b84e70128be347c3c6" +dependencies = [ + "term", +] + +[[package]] +name = "askama" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b79091df18a97caea757e28cd2d5fda49c6cd4bd01ddffd7ff01ace0c0ad2c28" +dependencies = [ + "askama_derive", + "askama_escape", +] + +[[package]] +name = "askama_derive" +version = "0.12.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a0fc7dcf8bd4ead96b1d36b41df47c14beedf7b0301fc543d8f2384e66a2ec0" +dependencies = [ + "askama_parser", + "basic-toml", + "mime", + "mime_guess", + "proc-macro2", + "quote", + "serde", + "syn 2.0.22", +] + +[[package]] +name = "askama_escape" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "619743e34b5ba4e9703bba34deac3427c72507c7159f5fd030aea8cac0cfe341" + +[[package]] +name = "askama_parser" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c268a96e01a4c47c8c5c2472aaa570707e006a875ea63e819f75474ceedaf7b4" +dependencies = [ + "nom", +] + +[[package]] +name = "async-trait" +version = "0.1.68" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9ccdd8f2a161be9bd5c023df56f1b2a0bd1d83872ae53b71a84a12c9bf6e842" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.22", +] + +[[package]] +name = "async_io_stream" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "b6d7b9decdf35d8908a7e3ef02f64c5e9b1695e230154c0e8de3969142d9b94c" +dependencies = [ + "futures", + "pharos", + "rustc_version 0.4.0", +] + +[[package]] +name = "atty" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" +dependencies = [ + "hermit-abi 0.1.19", + "libc", + "winapi", +] + +[[package]] +name = "auto_impl" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fee3da8ef1276b0bee5dd1c7258010d8fffd31801447323115a25560e1327b89" +dependencies = [ + "proc-macro-error", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "autocfg" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" + +[[package]] +name = "base16ct" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c7f02d4ea65f2c1853089ffd8d2787bdbc63de2f0d29dedbcf8ccdfa0ccd4cf" + +[[package]] +name = "base64" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" + +[[package]] +name = "base64" +version = "0.21.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d" + +[[package]] +name = "base64ct" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" + +[[package]] +name = "basic-toml" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c0de75129aa8d0cceaf750b89013f0e08804d6ec61416da787b35ad0d7cddf1" +dependencies = [ + "serde", +] + +[[package]] +name = "bech32" +version = "0.7.3" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "2dabbe35f96fb9507f7330793dc490461b2962659ac5d427181e451a623751d1" + +[[package]] +name = "benchy" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a02b4df38b8de5ea77a5caee888d4dfaa50044b01ad3609f14bcbe93d4d609f3" +dependencies = [ + "benchy-macros", + "memory-stats", + "nix", + "serde", + "serde_json", +] + +[[package]] +name = "benchy-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24c27b215b0291aee9ce4167bca78bf1ecbaea93dae3020b9542c45843f0eda4" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.22", +] + +[[package]] +name = "bigdecimal" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6773ddc0eafc0e509fb60e48dff7f450f8e674a0686ae8605e8d9901bd5eefa" +dependencies = [ + "num-bigint", + "num-integer", + "num-traits", +] + +[[package]] +name = "bincode" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" +dependencies = [ + "serde", +] + +[[package]] +name = "bit-set" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1" +dependencies = [ + "bit-vec", +] + +[[package]] +name = "bit-vec" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb" + +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + +[[package]] +name = "bitflags" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"b4682ae6287fcf752ecaabbfcc7b6f9b72aa33933dc23a554d853aea8eea8635" + +[[package]] +name = "bitvec" +version = "0.17.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41262f11d771fd4a61aa3ce019fca363b4b6c282fca9da2a31186d3965a47a5c" +dependencies = [ + "either", + "radium 0.3.0", +] + +[[package]] +name = "bitvec" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c" +dependencies = [ + "funty", + "radium 0.7.0", + "tap", + "wyz", +] + +[[package]] +name = "blake2b_simd" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c2f0dc9a68c6317d884f97cc36cf5a3d20ba14ce404227df55e1af708ab04bc" +dependencies = [ + "arrayref", + "arrayvec 0.7.4", + "constant_time_eq 0.2.6", +] + +[[package]] +name = "block-buffer" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4152116fd6e9dadb291ae18fc1ec3575ed6d84c29642d97890f4b4a3417297e4" +dependencies = [ + "block-padding", + "generic-array", +] + +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + +[[package]] +name = "block-padding" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d696c370c750c948ada61c69a0ee2cbbb9c50b1019ddb86d9317157a99c2cae" + +[[package]] +name = "bs58" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "771fe0050b883fcc3ea2359b1a96bcfbc090b7116eae7c3c512c7a083fdf23d3" +dependencies = [ + "sha2 0.9.9", +] + +[[package]] +name = "bumpalo" +version = "3.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3e2c3daef883ecc1b5d58c15adae93470a91d425f3532ba1695849656af3fc1" 
+ +[[package]] +name = "byte-slice-cast" +version = "1.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3ac9f8b63eca6fd385229b3675f6cc0dc5c8a5c8a54a59d4f52ffd670d87b0c" + +[[package]] +name = "bytecount" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c676a478f63e9fa2dd5368a42f28bba0d6c560b775f38583c8bbaa7fcd67c9c" + +[[package]] +name = "byteorder" +version = "1.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" + +[[package]] +name = "bytes" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be" +dependencies = [ + "serde", +] + +[[package]] +name = "bzip2" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bdb116a6ef3f6c3698828873ad02c3014b3c85cadb88496095628e3ef1e347f8" +dependencies = [ + "bzip2-sys", + "libc", +] + +[[package]] +name = "bzip2-sys" +version = "0.1.11+1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "736a955f3fa7875102d57c82b8cac37ec45224a07fd32d58f9f7a186b6cd4cdc" +dependencies = [ + "cc", + "libc", + "pkg-config", +] + +[[package]] +name = "camino" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c530edf18f37068ac2d977409ed5cd50d53d73bc653c7647b48eb78976ac9ae2" +dependencies = [ + "serde", +] + +[[package]] +name = "cargo-platform" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cbdb825da8a5df079a43676dbe042702f1707b1109f713a01420fbb4cc71fa27" +dependencies = [ + "serde", +] + +[[package]] +name = "cargo_metadata" +version = "0.15.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"eee4243f1f26fc7a42710e7439c149e2b10b05472f88090acce52632f231a73a" +dependencies = [ + "camino", + "cargo-platform", + "semver 1.0.17", + "serde", + "serde_json", + "thiserror", +] + +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + +[[package]] +name = "cc" +version = "1.0.79" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f" +dependencies = [ + "jobserver", +] + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "chrono" +version = "0.4.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec837a71355b28f6556dbd569b37b3f363091c0bd4b2e735674521b4c5fd9bc5" +dependencies = [ + "android-tzdata", + "num-traits", +] + +[[package]] +name = "cipher" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" +dependencies = [ + "crypto-common", + "inout", +] + +[[package]] +name = "clap" +version = "2.34.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0610544180c38b88101fecf2dd634b174a62eef6946f84dfc6a7127512b381c" +dependencies = [ + "bitflags 1.3.2", + "textwrap", + "unicode-width", +] + +[[package]] +name = "clap" +version = "4.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9394150f5b4273a1763355bd1c2ec54cc5a2593f790587bcd6b2c947cfa9211" +dependencies = [ + "clap_builder", + "clap_derive", + "once_cell", +] + +[[package]] +name = "clap_builder" +version = "4.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"9a78fbdd3cc2914ddf37ba444114bc7765bbdcb55ec9cbe6fa054f0137400717" +dependencies = [ + "anstream", + "anstyle", + "bitflags 1.3.2", + "clap_lex", + "strsim 0.10.0", +] + +[[package]] +name = "clap_derive" +version = "4.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8cd2b2a819ad6eec39e8f1d6b53001af1e5469f8c177579cdaeb313115b825f" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 2.0.22", +] + +[[package]] +name = "clap_lex" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2da6da31387c7e4ef160ffab6d5e7f00c42626fe39aea70a7b0f1773f7dd6c1b" + +[[package]] +name = "coins-bip32" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b30a84aab436fcb256a2ab3c80663d8aec686e6bae12827bb05fef3e1e439c9f" +dependencies = [ + "bincode", + "bs58", + "coins-core", + "digest 0.10.7", + "getrandom", + "hmac", + "k256", + "lazy_static", + "serde", + "sha2 0.10.7", + "thiserror", +] + +[[package]] +name = "coins-bip39" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "84f4d04ee18e58356accd644896aeb2094ddeafb6a713e056cef0c0a8e468c15" +dependencies = [ + "bitvec 0.17.4", + "coins-bip32", + "getrandom", + "hmac", + "once_cell", + "pbkdf2 0.12.1", + "rand 0.8.5", + "sha2 0.10.7", + "thiserror", +] + +[[package]] +name = "coins-core" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b949a1c63fb7eb591eb7ba438746326aedf0ae843e51ec92ba6bec5bb382c4f" +dependencies = [ + "base64 0.21.2", + "bech32", + "bs58", + "digest 0.10.7", + "generic-array", + "hex", + "ripemd", + "serde", + "serde_derive", + "sha2 0.10.7", + "sha3 0.10.8", + "thiserror", +] + +[[package]] +name = "colorchoice" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" + +[[package]] 
+name = "colored" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3616f750b84d8f0de8a58bda93e08e2a81ad3f523089b05f1dffecab48c6cbd" +dependencies = [ + "atty", + "lazy_static", + "winapi", +] + +[[package]] +name = "colored_json" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74cb9ce6b86f6e54bfa9518df2eeeef65d424ec7244d083ed97229185e366a91" +dependencies = [ + "is-terminal", + "serde", + "serde_json", + "yansi", +] + +[[package]] +name = "console" +version = "0.15.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c926e00cc70edefdc64d3a5ff31cc65bb97a3460097762bd23afb4d8145fccf8" +dependencies = [ + "encode_unicode", + "lazy_static", + "libc", + "unicode-width", + "windows-sys 0.45.0", +] + +[[package]] +name = "console_error_panic_hook" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06aeb73f470f66dcdbf7223caeebb85984942f22f1adb2a088cf9668146bbbc" +dependencies = [ + "cfg-if", + "wasm-bindgen", +] + +[[package]] +name = "const-oid" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "520fbf3c07483f94e3e3ca9d0cfd913d7718ef2483d2cfd91c0d9e91474ab913" + +[[package]] +name = "constant_time_eq" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc" + +[[package]] +name = "constant_time_eq" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21a53c0a4d288377e7415b53dcfc3c04da5cdc2cc95c8d5ac178b58f0b861ad6" + +[[package]] +name = "convert_case" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6245d59a3e82a7fc217c5828a6692dbc6dfb63a0c8c90495621f7b9d79704a0e" + +[[package]] +name = "core-foundation" +version = "0.9.3" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "194a7a9e6de53fa55116934067c844d9d749312f75c6f6d0980e8c252f8c2146" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa" + +[[package]] +name = "cpufeatures" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03e69e28e9f7f77debdedbaafa2866e1de9ba56df55a8bd7cfc724c25a09987c" +dependencies = [ + "libc", +] + +[[package]] +name = "crc32fast" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b540bd8bc810d3885c6ea91e2018302f68baba2129ab3e88f32389ee9370880d" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "criterion" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b01d6de93b2b6c65e17c634a26653a29d107b3c98c607c765bf38d041531cd8f" +dependencies = [ + "atty", + "cast", + "clap 2.34.0", + "criterion-plot", + "csv", + "itertools 0.10.5", + "lazy_static", + "num-traits", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_cbor", + "serde_derive", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2673cc8207403546f45f5fd319a974b1e6983ad1a3ee7e6041650013be041876" +dependencies = [ + "cast", + "itertools 0.10.5", +] + +[[package]] +name = "crossbeam-channel" +version = "0.5.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a33c2bf77f2df06183c3aa30d1e96c0695a313d4f9c453cc3762a6db39f99200" +dependencies = [ + "cfg-if", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "ce6fd6f855243022dcecf8702fef0c297d4338e226845fe067f6341ad9fa0cef" +dependencies = [ + "cfg-if", + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae211234986c545741a7dc064309f67ee1e5ad243d0e48335adc0484d960bcc7" +dependencies = [ + "autocfg", + "cfg-if", + "crossbeam-utils", + "memoffset 0.9.0", + "scopeguard", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a22b2d63d4d1dc0b7f1b6b2747dd0088008a9be28b6ddf0b1e7d335e3037294" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "crunchy" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" + +[[package]] +name = "crypto-bigint" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf4c2f4e1afd912bc40bfd6fed5d9dc1f288e0ba01bfcc835cc5bc3eb13efe15" +dependencies = [ + "generic-array", + "rand_core 0.6.4", + "subtle", + "zeroize", +] + +[[package]] +name = "crypto-common" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +dependencies = [ + "generic-array", + "typenum", +] + +[[package]] +name = "csv" +version = "1.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "626ae34994d3d8d668f4269922248239db4ae42d538b14c398b74a52208e8086" +dependencies = [ + "csv-core", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "csv-core" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b2466559f260f48ad25fe6317b3c8dac77b5bdb5763ac7d9d6103530663bc90" +dependencies = [ + "memchr", +] + +[[package]] +name = "ctr" +version = "0.9.2" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "0369ee1ad671834580515889b80f2ea915f23b8be8d0daa4bbaf2ac5c7590835" +dependencies = [ + "cipher", +] + +[[package]] +name = "cuda-config" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ee74643f7430213a1a78320f88649de309b20b80818325575e393f848f79f5d" +dependencies = [ + "glob", +] + +[[package]] +name = "cuda-driver-sys" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d4c552cc0de854877d80bcd1f11db75d42be32962d72a6799b88dcca88fffbd" +dependencies = [ + "cuda-config", +] + +[[package]] +name = "darling" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d706e75d87e35569db781a9b5e2416cff1236a47ed380831f959382ccd5f858" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0c960ae2da4de88a91b2d920c2a7233b400bc33cb28453a2987822d8392519b" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim 0.9.3", + "syn 1.0.109", +] + +[[package]] +name = "darling_macro" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9b5a2f4ac4969822c62224815d069952656cadc7084fdca9751e6d959189b72" +dependencies = [ + "darling_core", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "der" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56acb310e15652100da43d130af8d97b509e95af61aab1c5a7939ef24337ee17" +dependencies = [ + "const-oid", + "zeroize", +] + +[[package]] +name = "derivative" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcc3dd5e9e9c0b295d6e1e4d811fb6f157d5ffd784b8d202fc62eac8035a770b" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + 
+[[package]] +name = "derive-new" +version = "0.5.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3418329ca0ad70234b9735dc4ceed10af4df60eff9c8e7b06cb5e520d92c3535" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "derive_builder" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2658621297f2cf68762a6f7dc0bb7e1ff2cfd6583daef8ee0fed6f7ec468ec0" +dependencies = [ + "darling", + "derive_builder_core", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "derive_builder_core" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2791ea3e372c8495c0bc2033991d76b512cd799d07491fbd6890124db9458bef" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "derive_more" +version = "0.99.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fb810d30a7c1953f91334de7244731fc3f3c10d7fe163338a35b9f640960321" +dependencies = [ + "convert_case", + "proc-macro2", + "quote", + "rustc_version 0.4.0", + "syn 1.0.109", +] + +[[package]] +name = "diff" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8" + +[[package]] +name = "digest" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3dd60d1080a57a05ab032377049e0591415d2b31afd7028356dbf3cc6dcb066" +dependencies = [ + "generic-array", +] + +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer 0.10.4", + "const-oid", + "crypto-common", + "subtle", +] + +[[package]] +name = "dirs" +version = "5.0.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "44c45a9d03d6676652bcb5e724c7e988de1acad23a711b5217ab9cbecbec2225" +dependencies = [ + "dirs-sys", +] + +[[package]] +name = "dirs-next" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b98cf8ebf19c3d1b223e151f99a4f9f0690dca41414773390fc824184ac833e1" +dependencies = [ + "cfg-if", + "dirs-sys-next", +] + +[[package]] +name = "dirs-sys" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "520f05a5cbd335fae5a99ff7a6ab8627577660ee5cfd6a94a6a929b52ff0321c" +dependencies = [ + "libc", + "option-ext", + "redox_users", + "windows-sys 0.48.0", +] + +[[package]] +name = "dirs-sys-next" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ebda144c4fe02d1f7ea1a7d9641b6fc6b580adcfa024ae48797ecdeb6825b4d" +dependencies = [ + "libc", + "redox_users", + "winapi", +] + +[[package]] +name = "doc-comment" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" + +[[package]] +name = "downcast-rs" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ea835d29036a4087793836fa931b08837ad5e957da9e23886b29586fb9b6650" + +[[package]] +name = "dunce" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56ce8c6da7551ec6c462cbaf3bfbc75131ebbfa1c944aeaa9dab51ca1c5f0c3b" + +[[package]] +name = "dyn-clone" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68b0cf012f1230e43cd00ebb729c6bb58707ecfa8ad08b52ef3a4ccd2697fc30" + +[[package]] +name = "ecc" +version = "0.1.0" +source = "git+https://github.com/zkonduit/halo2wrong?branch=ac/chunked-mv-lookup#c1d7551c82953829caee30fe218759b0d2657d26" +dependencies = [ + "integer", + "num-bigint", + 
"num-integer", + "num-traits", + "rand 0.8.5", + "subtle", +] + +[[package]] +name = "ecdsa" +version = "0.16.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0997c976637b606099b9985693efa3581e84e41f5c11ba5255f88711058ad428" +dependencies = [ + "der", + "digest 0.10.7", + "elliptic-curve", + "rfc6979", + "signature", + "spki", +] + +[[package]] +name = "either" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91" + +[[package]] +name = "elliptic-curve" +version = "0.13.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "968405c8fdc9b3bf4df0a6638858cc0b52462836ab6b1c87377785dd09cf1c0b" +dependencies = [ + "base16ct", + "crypto-bigint", + "digest 0.10.7", + "ff", + "generic-array", + "group", + "pkcs8", + "rand_core 0.6.4", + "sec1", + "subtle", + "zeroize", +] + +[[package]] +name = "ena" +version = "0.14.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c533630cf40e9caa44bd91aadc88a75d75a4c3a12b4cfde353cbed41daa1e1f1" +dependencies = [ + "log", +] + +[[package]] +name = "encode_unicode" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" + +[[package]] +name = "encoding_rs" +version = "0.8.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071a31f4ee85403370b58aca746f01041ede6f0da2730960ad001edc2b71b394" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "enr" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf56acd72bb22d2824e66ae8e9e5ada4d0de17a69c7fd35569dde2ada8ec9116" +dependencies = [ + "base64 0.13.1", + "bytes", + "hex", + "k256", + "log", + "rand 0.8.5", + "rlp", + "serde", + "sha3 0.10.8", + "zeroize", +] + +[[package]] +name = "enumn" +version = "0.1.8" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48016319042fb7c87b78d2993084a831793a897a5cd1a2a67cab9d1eeb4b7d76" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.22", +] + +[[package]] +name = "env_logger" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85cdab6a89accf66733ad5a1693a4dcced6aeff64602b634530dd73c1f3ee9f0" +dependencies = [ + "humantime", + "is-terminal", + "log", + "regex", + "termcolor", +] + +[[package]] +name = "equivalent" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88bffebc5d80432c9b140ee17875ff173a8ab62faad5b257da912bd2f6c1c0a1" + +[[package]] +name = "errno" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4bcfec3a70f97c962c307b2d2c56e358cf1d00b558d74262b5f929ee8cc7e73a" +dependencies = [ + "errno-dragonfly", + "libc", + "windows-sys 0.48.0", +] + +[[package]] +name = "errno-dragonfly" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa68f1b12764fab894d2755d2518754e71b4fd80ecfb822714a1206c2aab39bf" +dependencies = [ + "cc", + "libc", +] + +[[package]] +name = "eth-keystore" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fda3bf123be441da5260717e0661c25a2fd9cb2b2c1d20bf2e05580047158ab" +dependencies = [ + "aes", + "ctr", + "digest 0.10.7", + "hex", + "hmac", + "pbkdf2 0.11.0", + "rand 0.8.5", + "scrypt", + "serde", + "serde_json", + "sha2 0.10.7", + "sha3 0.10.8", + "thiserror", + "uuid", +] + +[[package]] +name = "ethabi" +version = "18.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7413c5f74cc903ea37386a8965a936cbeb334bd270862fdece542c1b2dcbc898" +dependencies = [ + "ethereum-types", + "hex", + "once_cell", + "regex", + "serde", + "serde_json", + "sha3 0.10.8", + "thiserror", + "uint", +] + +[[package]] +name = 
"ethbloom" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c22d4b5885b6aa2fe5e8b9329fb8d232bf739e434e6b87347c63bdd00c120f60" +dependencies = [ + "crunchy", + "fixed-hash", + "impl-codec", + "impl-rlp", + "impl-serde", + "scale-info", + "tiny-keccak", +] + +[[package]] +name = "ethereum-types" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02d215cbf040552efcbe99a38372fe80ab9d00268e20012b79fcd0f073edd8ee" +dependencies = [ + "ethbloom", + "fixed-hash", + "impl-codec", + "impl-rlp", + "impl-serde", + "primitive-types", + "scale-info", + "uint", +] + +[[package]] +name = "ethers" +version = "2.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a58ce802c65cf3d0756dee5a61094a92cde53c1583b246e9ee5b37226c7fc15" +dependencies = [ + "ethers-addressbook", + "ethers-contract", + "ethers-core", + "ethers-etherscan", + "ethers-middleware", + "ethers-providers", + "ethers-signers", + "ethers-solc", +] + +[[package]] +name = "ethers-addressbook" +version = "2.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b856b7b8ff5c961093cb8efe151fbcce724b451941ce20781de11a531ccd578" +dependencies = [ + "ethers-core", + "once_cell", + "serde", + "serde_json", +] + +[[package]] +name = "ethers-contract" +version = "2.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e066a0d9cfc70c454672bf16bb433b0243427420076dc5b2f49c448fb5a10628" +dependencies = [ + "ethers-contract-abigen", + "ethers-contract-derive", + "ethers-core", + "ethers-providers", + "futures-util", + "hex", + "once_cell", + "pin-project", + "serde", + "serde_json", + "thiserror", +] + +[[package]] +name = "ethers-contract-abigen" +version = "2.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c113e3e86b6bc16d98484b2c3bb2d01d6fed9f489fe2e592e5cc87c3024d616b" +dependencies = [ + "Inflector", + 
"dunce", + "ethers-core", + "eyre", + "hex", + "prettyplease", + "proc-macro2", + "quote", + "regex", + "serde", + "serde_json", + "syn 2.0.22", + "toml", + "walkdir", +] + +[[package]] +name = "ethers-contract-derive" +version = "2.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c3fb5adee25701c79ec58fcf2c63594cd8829bc9ad6037ff862d5a111101ed2" +dependencies = [ + "Inflector", + "ethers-contract-abigen", + "ethers-core", + "hex", + "proc-macro2", + "quote", + "serde_json", + "syn 2.0.22", +] + +[[package]] +name = "ethers-core" +version = "2.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6da5fa198af0d3be20c19192df2bd9590b92ce09a8421e793bec8851270f1b05" +dependencies = [ + "arrayvec 0.7.4", + "bytes", + "cargo_metadata", + "chrono", + "elliptic-curve", + "ethabi", + "generic-array", + "hex", + "k256", + "num_enum", + "once_cell", + "open-fastrlp", + "rand 0.8.5", + "rlp", + "serde", + "serde_json", + "strum", + "syn 2.0.22", + "tempfile", + "thiserror", + "tiny-keccak", + "unicode-xid", +] + +[[package]] +name = "ethers-etherscan" +version = "2.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "84ebb401ba97c6f5af278c2c9936c4546cad75dec464b439ae6df249906f4caa" +dependencies = [ + "ethers-core", + "ethers-solc", + "reqwest", + "semver 1.0.17", + "serde", + "serde_json", + "thiserror", + "tracing", +] + +[[package]] +name = "ethers-middleware" +version = "2.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "740f4a773c19dd6d6a68c8c2e0996c096488d38997d524e21dc612c55da3bd24" +dependencies = [ + "async-trait", + "auto_impl", + "ethers-contract", + "ethers-core", + "ethers-etherscan", + "ethers-providers", + "ethers-signers", + "futures-channel", + "futures-locks", + "futures-util", + "instant", + "reqwest", + "serde", + "serde_json", + "thiserror", + "tokio", + "tracing", + "tracing-futures", + "url", +] + +[[package]] +name = 
"ethers-providers" +version = "2.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56b498fd2a6c019d023e43e83488cd1fb0721f299055975aa6bac8dbf1e95f2c" +dependencies = [ + "async-trait", + "auto_impl", + "base64 0.21.2", + "bytes", + "enr", + "ethers-core", + "futures-core", + "futures-timer", + "futures-util", + "hashers", + "hex", + "http", + "instant", + "once_cell", + "pin-project", + "reqwest", + "serde", + "serde_json", + "thiserror", + "tokio", + "tracing", + "tracing-futures", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", + "ws_stream_wasm", +] + +[[package]] +name = "ethers-signers" +version = "2.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02c4b7e15f212fa7cc2e1251868320221d4ff77a3d48068e69f47ce1c491df2d" +dependencies = [ + "async-trait", + "coins-bip32", + "coins-bip39", + "elliptic-curve", + "eth-keystore", + "ethers-core", + "hex", + "rand 0.8.5", + "sha2 0.10.7", + "thiserror", + "tracing", +] + +[[package]] +name = "ethers-solc" +version = "2.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a81c89f121595cf8959e746045bb8b25a6a38d72588561e1a3b7992fc213f674" +dependencies = [ + "cfg-if", + "dunce", + "ethers-core", + "glob", + "hex", + "home", + "md-5", + "num_cpus", + "once_cell", + "path-slash", + "rayon", + "regex", + "semver 1.0.17", + "serde", + "serde_json", + "solang-parser", + "thiserror", + "tiny-keccak", + "tokio", + "tracing", + "walkdir", + "yansi", +] + +[[package]] +name = "eyre" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c2b6b5a29c02cdc822728b7d7b8ae1bab3e3b05d44522770ddd49722eeac7eb" +dependencies = [ + "indenter", + "once_cell", +] + +[[package]] +name = "ezkl" +version = "0.0.0" +dependencies = [ + "ark-std 0.3.0", + "benchy", + "bincode", + "clap 4.3.8", + "colored", + "colored_json", + "console_error_panic_hook", + "criterion", + "ecc", + "env_logger", + 
"ethers", + "gag", + "getrandom", + "halo2", + "halo2_gadgets", + "halo2_proofs", + "halo2_solidity_verifier", + "halo2curves 0.1.0", + "hex", + "indicatif", + "instant", + "itertools 0.10.5", + "lazy_static", + "log", + "mnist", + "num", + "openssl", + "pg_bigdecimal", + "plotters", + "postgres", + "pyo3", + "pyo3-asyncio", + "pyo3-log", + "rand 0.8.5", + "rayon", + "regex", + "reqwest", + "seq-macro", + "serde", + "serde-wasm-bindgen", + "serde_json", + "shellexpand", + "snark-verifier", + "tabled", + "tch", + "tempdir", + "tempfile", + "test-case", + "thiserror", + "tokio", + "tokio-util", + "tract-onnx", + "unzip-n", + "wasm-bindgen", + "wasm-bindgen-console-logger", + "wasm-bindgen-rayon", + "wasm-bindgen-test", +] + +[[package]] +name = "fallible-iterator" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7" + +[[package]] +name = "fastrand" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e51093e27b0797c359783294ca4f0a911c270184cb10f85783b118614a1501be" +dependencies = [ + "instant", +] + +[[package]] +name = "fastrlp" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "139834ddba373bbdd213dffe02c8d110508dcf1726c2be27e8d1f7d7e1856418" +dependencies = [ + "arrayvec 0.7.4", + "auto_impl", + "bytes", +] + +[[package]] +name = "ff" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ded41244b729663b1e574f1b4fb731469f69f79c17667b5d776b16cda0479449" +dependencies = [ + "bitvec 1.0.1", + "rand_core 0.6.4", + "subtle", +] + +[[package]] +name = "filedescriptor" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7199d965852c3bac31f779ef99cbb4537f80e952e2d6aa0ffeb30cce00f4f46e" +dependencies = [ + "libc", + "thiserror", + "winapi", +] + +[[package]] +name = "filetime" 
+version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5cbc844cecaee9d4443931972e1289c8ff485cb4cc2767cb03ca139ed6885153" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall 0.2.16", + "windows-sys 0.48.0", +] + +[[package]] +name = "fixed-hash" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "835c052cb0c08c1acf6ffd71c022172e18723949c8282f2b9f27efbc51e64534" +dependencies = [ + "byteorder", + "rand 0.8.5", + "rustc-hex", + "static_assertions", +] + +[[package]] +name = "fixedbitset" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" + +[[package]] +name = "flate2" +version = "1.0.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b9429470923de8e8cbd4d2dc513535400b4b3fef0319fb5c4e1f520a7bef743" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + +[[package]] +name = "form_urlencoded" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a62bc1cf6f830c2ec14a513a9fb124d0a213a629668a4186f329db21fe045652" +dependencies = [ + "percent-encoding", +] + +[[package]] +name = "fuchsia-cprng" +version = "0.1.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06f77d526c1a601b7c4cdd98f54b5eaabffc14d5f2f0296febdc7f357c6d3ba" + +[[package]] +name = "funty" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" + +[[package]] +name = "futures" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23342abe12aba583913b2e62f22225ff9c950774065e4bfb61a19cd9770fec40" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "955518d47e09b25bbebc7a18df10b81f0c766eaf4c4f1cccef2fca5f2a4fb5f2" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c" + +[[package]] +name = "futures-executor" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccecee823288125bd88b4d7f565c9e58e41858e47ab72e8ea2d64e93624386e0" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964" + +[[package]] +name = "futures-locks" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45ec6fe3675af967e67c5536c0b9d44e34e6c52f86bedc4ea49c5317b8e94d06" +dependencies = [ + "futures-channel", + "futures-task", +] + +[[package]] +name = "futures-macro" +version = "0.3.28" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.22", +] + +[[package]] +name = "futures-sink" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f43be4fe21a13b9781a69afa4985b0f6ee0e1afab2c6f454a8cf30e2b2237b6e" + +[[package]] +name = "futures-task" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65" + +[[package]] +name = "futures-timer" +version = "3.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e64b03909df88034c26dc1547e8970b91f98bdb65165d6a4e9110d94263dbb2c" +dependencies = [ + "gloo-timers", + "send_wrapper 0.4.0", +] + +[[package]] +name = "futures-util" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "slab", +] + +[[package]] +name = "fxhash" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c31b6d751ae2c7f11320402d34e41349dd1016f8d5d45e48c4312bc8625af50c" +dependencies = [ + "byteorder", +] + +[[package]] +name = "gag" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a713bee13966e9fbffdf7193af71d54a6b35a0bb34997cd6c9519ebeb5005972" +dependencies = [ + "filedescriptor", + "tempfile", +] + +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", + 
"zeroize", +] + +[[package]] +name = "getrandom" +version = "0.2.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427" +dependencies = [ + "cfg-if", + "js-sys", + "libc", + "wasi", + "wasm-bindgen", +] + +[[package]] +name = "glob" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" + +[[package]] +name = "gloo-timers" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b995a66bb87bebce9a0f4a95aed01daca4872c050bfcb21653361c03bc35e5c" +dependencies = [ + "futures-channel", + "futures-core", + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "group" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0f9ef7462f7c099f518d754361858f86d8a07af53ba9af0fe635bbccb151a63" +dependencies = [ + "ff", + "rand_core 0.6.4", + "subtle", +] + +[[package]] +name = "h2" +version = "0.3.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97ec8491ebaf99c8eaa73058b045fe58073cd6be7f596ac993ced0b0a0c01049" +dependencies = [ + "bytes", + "fnv", + "futures-core", + "futures-sink", + "futures-util", + "http", + "indexmap 1.9.3", + "slab", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "half" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7" + +[[package]] +name = "half" +version = "2.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02b4af3693f1b705df946e9fe5631932443781d0aabb423b62fcd4d73f6d2fd0" +dependencies = [ + "crunchy", + "num-traits", +] + +[[package]] +name = "halo2" +version = "0.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"f7e3ae7c898be32ce6dd6621cf4a70b504b551adc4ff43bd08a7197a5ac1f308" + +[[package]] +name = "halo2_gadgets" +version = "0.2.0" +source = "git+https://github.com/zkonduit/halo2?branch=ac/lookup-modularity#57b9123835aa7d8482f4182ede3e8f4b0aea5c0a" +dependencies = [ + "arrayvec 0.7.4", + "bitvec 1.0.1", + "ff", + "group", + "halo2_proofs", + "halo2curves 0.1.0", + "lazy_static", + "rand 0.8.5", + "subtle", + "uint", +] + +[[package]] +name = "halo2_proofs" +version = "0.2.0" +source = "git+https://github.com/zkonduit/halo2?branch=ac/lookup-modularity#57b9123835aa7d8482f4182ede3e8f4b0aea5c0a" +dependencies = [ + "blake2b_simd", + "env_logger", + "ff", + "group", + "halo2curves 0.1.0", + "icicle", + "log", + "maybe-rayon", + "plotters", + "rand_chacha", + "rand_core 0.6.4", + "rustacuda", + "sha3 0.9.1", + "tabbycat", + "tracing", +] + +[[package]] +name = "halo2_solidity_verifier" +version = "0.1.0" +source = "git+https://github.com/alexander-camuto/halo2-solidity-verifier?branch=ac/lookup-modularity#cf9a3128bb583680dd4c418defd8d37bd8e5c3f1" +dependencies = [ + "askama", + "blake2b_simd", + "halo2_proofs", + "hex", + "itertools 0.11.0", + "ruint", + "sha3 0.10.8", +] + +[[package]] +name = "halo2curves" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6b1142bd1059aacde1b477e0c80c142910f1ceae67fc619311d6a17428007ab" +dependencies = [ + "blake2b_simd", + "ff", + "group", + "lazy_static", + "num-bigint", + "num-traits", + "pasta_curves", + "paste", + "rand 0.8.5", + "rand_core 0.6.4", + "serde", + "serde_arrays", + "static_assertions", + "subtle", +] + +[[package]] +name = "halo2curves" +version = "0.3.2" +source = "git+https://github.com/privacy-scaling-explorations/halo2curves?tag=0.3.2#9f5c50810bbefe779ee5cf1d852b2fe85dc35d5e" +dependencies = [ + "ff", + "group", + "lazy_static", + "num-bigint", + "num-traits", + "pasta_curves", + "paste", + "rand 0.8.5", + "rand_core 0.6.4", + "static_assertions", + "subtle", +] + 
+[[package]] +name = "halo2wrong" +version = "0.1.0" +source = "git+https://github.com/zkonduit/halo2wrong?branch=ac/chunked-mv-lookup#c1d7551c82953829caee30fe218759b0d2657d26" +dependencies = [ + "halo2_proofs", + "num-bigint", + "num-integer", + "num-traits", +] + +[[package]] +name = "hashbrown" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e" +dependencies = [ + "ahash 0.7.6", +] + +[[package]] +name = "hashbrown" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" + +[[package]] +name = "hashbrown" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43a3c133739dddd0d2990f9a4bdf8eb4b21ef50e4851ca85ab661199821d510e" +dependencies = [ + "ahash 0.8.3", +] + +[[package]] +name = "hashbrown" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a" + +[[package]] +name = "hashers" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2bca93b15ea5a746f220e56587f71e73c6165eab783df9e26590069953e3c30" +dependencies = [ + "fxhash", +] + +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + +[[package]] +name = "hermit-abi" +version = "0.1.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33" +dependencies = [ + "libc", +] + +[[package]] +name = "hermit-abi" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee512640fe35acbfb4bb779db6f0d80704c2cacfa2e39b601ef3e3f47d1ae4c7" 
+dependencies = [ + "libc", +] + +[[package]] +name = "hermit-abi" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286" + +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + +[[package]] +name = "hex-literal" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fe2267d4ed49bc07b63801559be28c718ea06c4738b7a03c94df7386d2cde46" + +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest 0.10.7", +] + +[[package]] +name = "home" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5444c27eef6923071f7ebcc33e3444508466a76f7a2b93da00ed6e19f30c1ddb" +dependencies = [ + "windows-sys 0.48.0", +] + +[[package]] +name = "http" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd6effc99afb63425aff9b05836f029929e345a6148a14b7ecd5ab67af944482" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + +[[package]] +name = "http-body" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5f38f16d184e36f2408a55281cd658ecbd3ca05cce6d6510a176eca393e26d1" +dependencies = [ + "bytes", + "http", + "pin-project-lite", +] + +[[package]] +name = "httparse" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904" + +[[package]] +name = "httpdate" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"c4a1e36c821dbe04574f602848a19f742f4fb3c98d40449f11bcad18d6b17421" + +[[package]] +name = "humantime" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" + +[[package]] +name = "hyper" +version = "0.14.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffb1cfd654a8219eaef89881fdb3bb3b1cdc5fa75ded05d6933b2b382e395468" +dependencies = [ + "bytes", + "futures-channel", + "futures-core", + "futures-util", + "h2", + "http", + "http-body", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "socket2 0.4.9", + "tokio", + "tower-service", + "tracing", + "want", +] + +[[package]] +name = "hyper-tls" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905" +dependencies = [ + "bytes", + "hyper", + "native-tls", + "tokio", + "tokio-native-tls", +] + +[[package]] +name = "icicle" +version = "0.1.0" +source = "git+https://github.com/ingonyama-zk/icicle.git?branch=rust/large-bucket-factor-msm#2820f43746263fbae7b30a29e08099ea19383c64" +dependencies = [ + "ark-bls12-377", + "ark-bls12-381", + "ark-bn254", + "ark-ec", + "ark-ff 0.3.0", + "ark-poly", + "ark-std 0.3.0", + "cc", + "hex", + "rand 0.8.5", + "rustacuda", + "rustacuda_core", + "rustacuda_derive", + "serde", + "serde_cbor", + "serde_derive", +] + +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + +[[package]] +name = "idna" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d20d6b07bfbc108882d88ed8e37d39636dcc260e15e30c45e6ba089610b917c" +dependencies = [ + "unicode-bidi", + "unicode-normalization", +] + +[[package]] +name = "impl-codec" +version = "0.6.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba6a270039626615617f3f36d15fc827041df3b78c439da2cadfa47455a77f2f" +dependencies = [ + "parity-scale-codec", +] + +[[package]] +name = "impl-rlp" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f28220f89297a075ddc7245cd538076ee98b01f2a9c23a53a4f1105d5a322808" +dependencies = [ + "rlp", +] + +[[package]] +name = "impl-serde" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc88fc67028ae3db0c853baa36269d398d5f45b6982f95549ff5def78c935cd" +dependencies = [ + "serde", +] + +[[package]] +name = "impl-trait-for-tuples" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11d7a9f6330b71fea57921c9b61c47ee6e84f72d394754eff6163ae67e7395eb" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "indenter" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce23b50ad8242c51a442f3ff322d56b02f08852c77e4c0b4d3fd684abc89c683" + +[[package]] +name = "indexmap" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" +dependencies = [ + "autocfg", + "hashbrown 0.12.3", +] + +[[package]] +name = "indexmap" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5477fe2230a79769d8dc68e0eabf5437907c0457a5614a9e8dddb67f65eb65d" +dependencies = [ + "equivalent", + "hashbrown 0.14.0", +] + +[[package]] +name = "indicatif" +version = "0.17.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ff8cc23a7393a397ed1d7f56e6365cba772aba9f9912ab968b03043c395d057" +dependencies = [ + "console", + "instant", + "number_prefix", + "portable-atomic", + "rayon", + "unicode-width", +] + +[[package]] +name = "indoc" +version = "1.0.9" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa799dd5ed20a7e349f3b4639aa80d74549c81716d9ec4f994c9b5815598306" + +[[package]] +name = "inout" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0c10553d664a4d0bcff9f4215d0aac67a639cc68ef660840afe309b807bc9f5" +dependencies = [ + "generic-array", +] + +[[package]] +name = "instant" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c" +dependencies = [ + "cfg-if", + "js-sys", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "integer" +version = "0.1.0" +source = "git+https://github.com/zkonduit/halo2wrong?branch=ac/chunked-mv-lookup#c1d7551c82953829caee30fe218759b0d2657d26" +dependencies = [ + "maingate", + "num-bigint", + "num-integer", + "num-traits", + "rand 0.8.5", + "subtle", +] + +[[package]] +name = "io-lifetimes" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eae7b9aee968036d54dce06cebaefd919e4472e753296daccd6d344e3e2df0c2" +dependencies = [ + "hermit-abi 0.3.1", + "libc", + "windows-sys 0.48.0", +] + +[[package]] +name = "ipnet" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28b29a3cd74f0f4598934efe3aeba42bae0eb4680554128851ebbecb02af14e6" + +[[package]] +name = "is-terminal" +version = "0.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adcf93614601c8129ddf72e2d5633df827ba6551541c6d8c59520a371475be1f" +dependencies = [ + "hermit-abi 0.3.1", + "io-lifetimes", + "rustix", + "windows-sys 0.48.0", +] + +[[package]] +name = "itertools" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" +dependencies = [ + "either", +] + +[[package]] +name = "itertools" +version = 
"0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57" +dependencies = [ + "either", +] + +[[package]] +name = "itoa" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6" + +[[package]] +name = "jobserver" +version = "0.1.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c37f63953c4c63420ed5fd3d6d398c719489b9f872b9fa683262f8edd363c7d" +dependencies = [ + "libc", +] + +[[package]] +name = "js-sys" +version = "0.3.64" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c5f195fe497f702db0f318b07fdd68edb16955aed830df8363d837542f8f935a" +dependencies = [ + "wasm-bindgen", +] + +[[package]] +name = "k256" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cadb76004ed8e97623117f3df85b17aaa6626ab0b0831e6573f104df16cd1bcc" +dependencies = [ + "cfg-if", + "ecdsa", + "elliptic-curve", + "once_cell", + "sha2 0.10.7", + "signature", +] + +[[package]] +name = "keccak" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f6d5ed8676d904364de097082f4e7d240b571b67989ced0240f08b7f966f940" +dependencies = [ + "cpufeatures", +] + +[[package]] +name = "kstring" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec3066350882a1cd6d950d055997f379ac37fd39f81cd4d8ed186032eb3c5747" +dependencies = [ + "serde", + "static_assertions", +] + +[[package]] +name = "lalrpop" +version = "0.19.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a1cbf952127589f2851ab2046af368fd20645491bb4b376f04b7f94d7a9837b" +dependencies = [ + "ascii-canvas", + "bit-set", + "diff", + "ena", + "is-terminal", + "itertools 0.10.5", + "lalrpop-util", + "petgraph", + "regex", + 
"regex-syntax 0.6.29", + "string_cache", + "term", + "tiny-keccak", + "unicode-xid", +] + +[[package]] +name = "lalrpop-util" +version = "0.19.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3c48237b9604c5a4702de6b824e02006c3214327564636aef27c1028a8fa0ed" + +[[package]] +name = "lazy_static" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" +dependencies = [ + "spin", +] + +[[package]] +name = "libc" +version = "0.2.147" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3" + +[[package]] +name = "libm" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7012b1bbb0719e1097c47611d3898568c546d597c2e74d66f6087edd5233ff4" + +[[package]] +name = "linux-raw-sys" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef53942eb7bf7ff43a617b3e2c1c4a5ecf5944a7c1bc12d7ee39bbb15e5c1519" + +[[package]] +name = "liquid" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69f68ae1011499ae2ef879f631891f21c78e309755f4a5e483c4a8f12e10b609" +dependencies = [ + "doc-comment", + "liquid-core", + "liquid-derive", + "liquid-lib", + "serde", +] + +[[package]] +name = "liquid-core" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79e0724dfcaad5cfb7965ea0f178ca0870b8d7315178f4a7179f5696f7f04d5f" +dependencies = [ + "anymap2", + "itertools 0.10.5", + "kstring", + "liquid-derive", + "num-traits", + "pest", + "pest_derive", + "regex", + "serde", + "time", +] + +[[package]] +name = "liquid-derive" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc2fb41a9bb4257a3803154bdf7e2df7d45197d1941c9b1a90ad815231630721" +dependencies = 
[ + "proc-macro2", + "quote", + "syn 2.0.22", +] + +[[package]] +name = "liquid-lib" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2a17e273a6fb1fb6268f7a5867ddfd0bd4683c7e19b51084f3d567fad4348c0" +dependencies = [ + "itertools 0.10.5", + "liquid-core", + "once_cell", + "percent-encoding", + "regex", + "time", + "unicode-segmentation", +] + +[[package]] +name = "lock_api" +version = "0.4.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1cc9717a20b1bb222f333e6a92fd32f7d8a18ddc5a3191a11af45dcbf4dcd16" +dependencies = [ + "autocfg", + "scopeguard", +] + +[[package]] +name = "log" +version = "0.4.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b06a4cde4c0f271a446782e3eff8de789548ce57dbc8eca9292c27f4a42004b4" + +[[package]] +name = "maingate" +version = "0.1.0" +source = "git+https://github.com/zkonduit/halo2wrong?branch=ac/chunked-mv-lookup#c1d7551c82953829caee30fe218759b0d2657d26" +dependencies = [ + "halo2wrong", + "num-bigint", + "num-integer", + "num-traits", + "rand 0.8.5", + "subtle", +] + +[[package]] +name = "maplit" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e2e65a1a2e43cfcb47a895c4c8b10d1f4a61097f9f254f183aee60cad9c651d" + +[[package]] +name = "matrixmultiply" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "090126dc04f95dc0d1c1c91f61bdd474b3930ca064c1edc8a849da2c6cbe1e77" +dependencies = [ + "autocfg", + "rawpointer", +] + +[[package]] +name = "maybe-rayon" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ea1f30cedd69f0a2954655f7188c6a834246d2bcf1e315e2ac40c4b24dc9519" +dependencies = [ + "cfg-if", + "rayon", +] + +[[package]] +name = "md-5" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"6365506850d44bff6e2fbcb5176cf63650e48bd45ef2fe2665ae1570e0f4b9ca" +dependencies = [ + "digest 0.10.7", +] + +[[package]] +name = "memchr" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" + +[[package]] +name = "memmap2" +version = "0.5.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83faa42c0a078c393f6b29d5db232d8be22776a891f8f56e5284faee4a20b327" +dependencies = [ + "libc", +] + +[[package]] +name = "memoffset" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d61c719bcfbcf5d62b3a09efa6088de8c54bc0bfcd3ea7ae39fcc186108b8de1" +dependencies = [ + "autocfg", +] + +[[package]] +name = "memoffset" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" +dependencies = [ + "autocfg", +] + +[[package]] +name = "memory-stats" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34f79cf9964c5c9545493acda1263f1912f8d2c56c8a2ffee2606cb960acaacc" +dependencies = [ + "libc", + "winapi", +] + +[[package]] +name = "mime" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" + +[[package]] +name = "mime_guess" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4192263c238a5f0d0c6bfd21f336a313a4ce1c450542449ca191bb657b4642ef" +dependencies = [ + "mime", + "unicase", +] + +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + +[[package]] +name = "miniz_oxide" +version = "0.7.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7810e0be55b428ada41041c41f32c9f1a42817901b4ccf45fa3d4b6561e74c7" +dependencies = [ + "adler", +] + +[[package]] +name = "mio" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "927a765cd3fc26206e66b296465fa9d3e5ab003e651c1b3c060e7956d96b19d2" +dependencies = [ + "libc", + "wasi", + "windows-sys 0.48.0", +] + +[[package]] +name = "mnist" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2126097226757059ceb317a872fb9e3b289e67af0ba2681b1e1cf389b763c260" +dependencies = [ + "byteorder", +] + +[[package]] +name = "native-tls" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07226173c32f2926027b63cce4bcd8076c3552846cbe7925f3aaffeac0a3b92e" +dependencies = [ + "lazy_static", + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework", + "security-framework-sys", + "tempfile", +] + +[[package]] +name = "ndarray" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "rawpointer", +] + +[[package]] +name = "new_debug_unreachable" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4a24736216ec316047a1fc4252e27dabb04218aa4a3f37c6e7ddbf1f9782b54" + +[[package]] +name = "nix" +version = "0.27.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2eb04e9c688eff1c89d72b407f168cf79bb9e867a9d3323ed6c01519eb9cc053" +dependencies = [ + "bitflags 2.4.0", + "cfg-if", + "libc", +] + +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + +[[package]] +name = "num" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b05180d69e3da0e530ba2a1dae5110317e49e3b7f3d41be227dc5f92e49ee7af" +dependencies = [ + "num-bigint", + "num-complex", + "num-integer", + "num-iter", + "num-rational", + "num-traits", +] + +[[package]] +name = "num-bigint" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f93ab6289c7b344a8a9f60f88d80aa20032336fe78da341afc91c8a2341fc75f" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", + "rand 0.8.5", +] + +[[package]] +name = "num-complex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02e0d21255c828d6f128a1e41534206671e8c3ea0c62f32291e808dc82cff17d" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" +dependencies = [ + "autocfg", + "num-traits", +] + +[[package]] +name = "num-iter" +version = "0.1.43" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d03e6c028c5dc5cac6e2dec0efda81fc887605bb3d884578bb6d6bf7514e252" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-rational" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0" +dependencies = [ + "autocfg", + "num-bigint", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" +dependencies = [ + "autocfg", + "libm", 
+] + +[[package]] +name = "num_cpus" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fac9e2da13b5eb447a6ce3d392f23a29d8694bff781bf03a16cd9ac8697593b" +dependencies = [ + "hermit-abi 0.2.6", + "libc", +] + +[[package]] +name = "num_enum" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a015b430d3c108a207fd776d2e2196aaf8b1cf8cf93253e3a097ff3085076a1" +dependencies = [ + "num_enum_derive", +] + +[[package]] +name = "num_enum_derive" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96667db765a921f7b295ffee8b60472b686a51d4f21c2ee4ffdb94c7013b65a6" +dependencies = [ + "proc-macro-crate", + "proc-macro2", + "quote", + "syn 2.0.22", +] + +[[package]] +name = "number_prefix" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" + +[[package]] +name = "once_cell" +version = "1.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" + +[[package]] +name = "oorandom" +version = "11.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" + +[[package]] +name = "opaque-debug" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" + +[[package]] +name = "open-fastrlp" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "786393f80485445794f6043fd3138854dd109cc6c4bd1a6383db304c9ce9b9ce" +dependencies = [ + "arrayvec 0.7.4", + "auto_impl", + "bytes", + "ethereum-types", + "open-fastrlp-derive", +] + +[[package]] +name = "open-fastrlp-derive" +version = "0.1.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "003b2be5c6c53c1cfeb0a238b8a1c3915cd410feb684457a36c10038f764bb1c" +dependencies = [ + "bytes", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "openssl" +version = "0.10.55" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "345df152bc43501c5eb9e4654ff05f794effb78d4efe3d53abc158baddc0703d" +dependencies = [ + "bitflags 1.3.2", + "cfg-if", + "foreign-types", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.22", +] + +[[package]] +name = "openssl-probe" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" + +[[package]] +name = "openssl-src" +version = "111.26.0+1.1.1u" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "efc62c9f12b22b8f5208c23a7200a442b2e5999f8bdf80233852122b5a4f6f37" +dependencies = [ + "cc", +] + +[[package]] +name = "openssl-sys" +version = "0.9.90" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "374533b0e45f3a7ced10fcaeccca020e66656bc03dac384f852e4e5a7a8104a6" +dependencies = [ + "cc", + "libc", + "openssl-src", + "pkg-config", + "vcpkg", +] + +[[package]] +name = "option-ext" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" + +[[package]] +name = "papergrid" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae7891b22598926e4398790c8fe6447930c72a67d36d983a49d6ce682ce83290" +dependencies = [ + "ansi-str", + "ansitok", + 
"bytecount", + "fnv", + "unicode-width", +] + +[[package]] +name = "parity-scale-codec" +version = "3.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2287753623c76f953acd29d15d8100bcab84d29db78fb6f352adb3c53e83b967" +dependencies = [ + "arrayvec 0.7.4", + "bitvec 1.0.1", + "byte-slice-cast", + "impl-trait-for-tuples", + "parity-scale-codec-derive", + "serde", +] + +[[package]] +name = "parity-scale-codec-derive" +version = "3.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b6937b5e67bfba3351b87b040d48352a2fcb6ad72f81855412ce97b45c8f110" +dependencies = [ + "proc-macro-crate", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "parking_lot" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93f00c865fe7cabf650081affecd3871070f26767e7b2070a3ffae14c654b447" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall 0.3.5", + "smallvec", + "windows-targets 0.48.0", +] + +[[package]] +name = "password-hash" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7676374caaee8a325c9e7a2ae557f216c5563a171d6997b0ef8a65af35147700" +dependencies = [ + "base64ct", + "rand_core 0.6.4", + "subtle", +] + +[[package]] +name = "pasta_curves" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3e57598f73cc7e1b2ac63c79c517b31a0877cd7c402cdcaa311b5208de7a095" +dependencies = [ + "blake2b_simd", + "ff", + "group", + "lazy_static", + "rand 0.8.5", + "static_assertions", + "subtle", +] + +[[package]] +name = "paste" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "9f746c4065a8fa3fe23974dd82f15431cc8d40779821001404d10d2e79ca7d79" + +[[package]] +name = "path-slash" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e91099d4268b0e11973f036e885d652fb0b21fedcf69738c627f94db6a44f42" + +[[package]] +name = "pbkdf2" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83a0692ec44e4cf1ef28ca317f14f8f07da2d95ec3fa01f86e4467b725e60917" +dependencies = [ + "digest 0.10.7", + "hmac", + "password-hash", + "sha2 0.10.7", +] + +[[package]] +name = "pbkdf2" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0ca0b5a68607598bf3bad68f32227a8164f6254833f84eafaac409cd6746c31" +dependencies = [ + "digest 0.10.7", + "hmac", +] + +[[package]] +name = "percent-encoding" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94" + +[[package]] +name = "pest" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f73935e4d55e2abf7f130186537b19e7a4abc886a0252380b59248af473a3fc9" +dependencies = [ + "thiserror", + "ucd-trie", +] + +[[package]] +name = "pest_derive" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aef623c9bbfa0eedf5a0efba11a5ee83209c326653ca31ff019bec3a95bfff2b" +dependencies = [ + "pest", + "pest_generator", +] + +[[package]] +name = "pest_generator" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3e8cba4ec22bada7fc55ffe51e2deb6a0e0db2d0b7ab0b103acc80d2510c190" +dependencies = [ + "pest", + "pest_meta", + "proc-macro2", + "quote", + "syn 2.0.22", +] + +[[package]] +name = "pest_meta" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"a01f71cb40bd8bb94232df14b946909e14660e33fc05db3e50ae2a82d7ea0ca0" +dependencies = [ + "once_cell", + "pest", + "sha2 0.10.7", +] + +[[package]] +name = "petgraph" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4dd7d28ee937e54fe3080c91faa1c3a46c06de6252988a7f4592ba2310ef22a4" +dependencies = [ + "fixedbitset", + "indexmap 1.9.3", +] + +[[package]] +name = "pg_bigdecimal" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9855a94c74528af62c0ea236577af5e601263c1c404a6ac939b07c97c8e0216" +dependencies = [ + "bigdecimal", + "byteorder", + "bytes", + "num", + "postgres", +] + +[[package]] +name = "pharos" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9567389417feee6ce15dd6527a8a1ecac205ef62c2932bcf3d9f6fc5b78b414" +dependencies = [ + "futures", + "rustc_version 0.4.0", +] + +[[package]] +name = "phf" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ade2d8b8f33c7333b51bcf0428d37e217e9f32192ae4772156f65063b8ce03dc" +dependencies = [ + "phf_macros", + "phf_shared 0.11.2", +] + +[[package]] +name = "phf_generator" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48e4cc64c2ad9ebe670cb8fd69dd50ae301650392e81c05f9bfcb2d5bdbc24b0" +dependencies = [ + "phf_shared 0.11.2", + "rand 0.8.5", +] + +[[package]] +name = "phf_macros" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3444646e286606587e49f3bcf1679b8cef1dc2c5ecc29ddacaffc305180d464b" +dependencies = [ + "phf_generator", + "phf_shared 0.11.2", + "proc-macro2", + "quote", + "syn 2.0.22", +] + +[[package]] +name = "phf_shared" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6796ad771acdc0123d2a88dc428b5e38ef24456743ddb1744ed628f9815c096" +dependencies = [ + "siphasher", +] + 
+[[package]] +name = "phf_shared" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90fcb95eef784c2ac79119d1dd819e162b5da872ce6f3c3abe1e8ca1c082f72b" +dependencies = [ + "siphasher", +] + +[[package]] +name = "pin-project" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c95a7476719eab1e366eaf73d0260af3021184f18177925b07f54b30089ceead" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39407670928234ebc5e6e580247dd567ad73a3578460c5990f9503df207e8f07" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.22", +] + +[[package]] +name = "pin-project-lite" +version = "0.2.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + +[[package]] +name = "pkcs8" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7" +dependencies = [ + "der", + "spki", +] + +[[package]] +name = "pkg-config" +version = "0.3.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964" + +[[package]] +name = "plotters" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2c224ba00d7cadd4d5c660deaf2098e5e80e07846537c51f9cfa4be50c1fd45" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.5" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e76628b4d3a7581389a35d5b6e2139607ad7c75b17aed325f210aa91f4a9609" + +[[package]] +name = "plotters-svg" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38f6d39893cca0701371e3c27294f09797214b86f1fb951b89ade8ec04e2abab" +dependencies = [ + "plotters-backend", +] + +[[package]] +name = "portable-atomic" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "767eb9f07d4a5ebcb39bbf2d452058a93c011373abf6832e24194a1c3f004794" + +[[package]] +name = "poseidon" +version = "0.2.0" +source = "git+https://github.com/privacy-scaling-explorations/poseidon.git?tag=v2023_04_20#807f8f555313f726ca03bdf941f798098f488ba4" +dependencies = [ + "halo2curves 0.3.2", + "subtle", +] + +[[package]] +name = "postgres" +version = "0.19.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bed5017bc2ff49649c0075d0d7a9d676933c1292480c1d137776fb205b5cd18" +dependencies = [ + "bytes", + "fallible-iterator", + "futures-util", + "log", + "tokio", + "tokio-postgres", +] + +[[package]] +name = "postgres-protocol" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78b7fa9f396f51dffd61546fd8573ee20592287996568e6175ceb0f8699ad75d" +dependencies = [ + "base64 0.21.2", + "byteorder", + "bytes", + "fallible-iterator", + "hmac", + "md-5", + "memchr", + "rand 0.8.5", + "sha2 0.10.7", + "stringprep", +] + +[[package]] +name = "postgres-types" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f028f05971fe20f512bcc679e2c10227e57809a3af86a7606304435bc8896cd6" +dependencies = [ + "bytes", + "fallible-iterator", + "postgres-protocol", +] + +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" 
+ +[[package]] +name = "precomputed-hash" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c" + +[[package]] +name = "prettyplease" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9825a04601d60621feed79c4e6b56d65db77cdca55cef43b46b0de1096d1c282" +dependencies = [ + "proc-macro2", + "syn 2.0.22", +] + +[[package]] +name = "primal-check" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9df7f93fd637f083201473dab4fee2db4c429d32e55e3299980ab3957ab916a0" +dependencies = [ + "num-integer", +] + +[[package]] +name = "primitive-types" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f3486ccba82358b11a77516035647c34ba167dfa53312630de83b12bd4f3d66" +dependencies = [ + "fixed-hash", + "impl-codec", + "impl-rlp", + "impl-serde", + "scale-info", + "uint", +] + +[[package]] +name = "proc-macro-crate" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f4c021e1093a56626774e81216a4ce732a735e5bad4868a03f3ed65ca0c3919" +dependencies = [ + "once_cell", + "toml_edit", +] + +[[package]] +name = "proc-macro-error" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" +dependencies = [ + "proc-macro-error-attr", + "proc-macro2", + "quote", + "syn 1.0.109", + "version_check", +] + +[[package]] +name = "proc-macro-error-attr" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" +dependencies = [ + "proc-macro2", + "quote", + "version_check", +] + +[[package]] +name = "proc-macro2" +version = "1.0.63" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "7b368fba921b0dce7e60f5e04ec15e565b3303972b42bcfde1d0713b881959eb" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "proptest" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c003ac8c77cb07bb74f5f198bce836a689bcd5a42574612bf14d17bfd08c20e" +dependencies = [ + "bitflags 2.4.0", + "lazy_static", + "num-traits", + "rand 0.8.5", + "rand_chacha", + "rand_xorshift", + "regex-syntax 0.7.2", + "unarray", +] + +[[package]] +name = "prost" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b82eaa1d779e9a4bc1c3217db8ffbeabaae1dca241bf70183242128d48681cd" +dependencies = [ + "bytes", + "prost-derive", +] + +[[package]] +name = "prost-derive" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5d2d8d10f3c6ded6da8b05b5fb3b8a5082514344d56c9f871412d29b4e075b4" +dependencies = [ + "anyhow", + "itertools 0.10.5", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "pyo3" +version = "0.18.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3b1ac5b3731ba34fdaa9785f8d74d17448cd18f30cf19e0c7e7b1fdb5272109" +dependencies = [ + "cfg-if", + "indoc", + "libc", + "memoffset 0.8.0", + "parking_lot", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-asyncio" +version = "0.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3564762e37035cfc486228e10b0528460fa026d681b5763873c693aa0d5c260" +dependencies = [ + "futures", + "once_cell", + "pin-project-lite", + "pyo3", + "pyo3-asyncio-macros", + "tokio", +] + +[[package]] +name = "pyo3-asyncio-macros" +version = "0.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be72d4cd43a27530306bd0d20d3932182fbdd072c6b98d3638bc37efb9d559dd" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] 
+name = "pyo3-build-config" +version = "0.18.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9cb946f5ac61bb61a5014924910d936ebd2b23b705f7a4a3c40b05c720b079a3" +dependencies = [ + "once_cell", + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.18.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd4d7c5337821916ea2a1d21d1092e8443cf34879e53a0ac653fbb98f44ff65c" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-log" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c94ff6535a6bae58d7d0b85e60d4c53f7f84d0d0aa35d6a28c3f3e70bfe51444" +dependencies = [ + "arc-swap", + "log", + "pyo3", +] + +[[package]] +name = "pyo3-macros" +version = "0.18.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9d39c55dab3fc5a4b25bbd1ac10a2da452c4aca13bb450f22818a002e29648d" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.18.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97daff08a4c48320587b5224cc98d609e3c27b6d437315bd40b605c98eeb5918" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "quote" +version = "1.0.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b9ab9c7eadfd8df19006f1cf1a4aed13540ed5cbc047010ece5826e10825488" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "radium" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "def50a86306165861203e7f84ecffbbdfdea79f0e51039b33de1e952358c47ac" + +[[package]] +name = "radium" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" + +[[package]] +name = "rand" +version = "0.4.6" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "552840b97013b1a26992c11eac34bdd778e464601a4c2054b5f0bff7c6761293" +dependencies = [ + "fuchsia-cprng", + "libc", + "rand_core 0.3.1", + "rdrand", + "winapi", +] + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core 0.6.4", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core 0.6.4", +] + +[[package]] +name = "rand_core" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a6fdeb83b075e8266dcc8762c22776f6877a63111121f5f8c7411e5be7eed4b" +dependencies = [ + "rand_core 0.4.2", +] + +[[package]] +name = "rand_core" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c33a3c44ca05fa6f1807d8e6743f3824e8509beca625669633be0acbdf509dc" + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + +[[package]] +name = "rand_distr" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +dependencies = [ + "num-traits", + "rand 0.8.5", +] + +[[package]] +name = "rand_xorshift" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d25bf25ec5ae4a3f1b92f929810509a2f53d7dca2f50b794ff57e3face536c8f" +dependencies = [ + "rand_core 0.6.4", +] + +[[package]] +name = "rawpointer" +version = "0.2.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + +[[package]] +name = "rayon" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d2df5196e37bcc87abebc0053e20787d73847bb33134a69841207dd0a47f03b" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b8f95bd6966f5c87776639160a66bd8ab9895d9d4ab01ddba9fc60661aebe8d" +dependencies = [ + "crossbeam-channel", + "crossbeam-deque", + "crossbeam-utils", + "num_cpus", +] + +[[package]] +name = "rdrand" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "678054eb77286b51581ba43620cc911abf02758c91f93f479767aed0f90458b2" +dependencies = [ + "rand_core 0.3.1", +] + +[[package]] +name = "redox_syscall" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a" +dependencies = [ + "bitflags 1.3.2", +] + +[[package]] +name = "redox_syscall" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" +dependencies = [ + "bitflags 1.3.2", +] + +[[package]] +name = "redox_users" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b033d837a7cf162d7993aded9304e30a83213c648b6e389db233191f891e5c2b" +dependencies = [ + "getrandom", + "redox_syscall 0.2.16", + "thiserror", +] + +[[package]] +name = "regex" +version = "1.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0ab3ca65655bb1e41f2a8c8cd662eb4fb035e67c3f78da1d61dffe89d07300f" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax 0.7.2", +] + +[[package]] +name = "regex-syntax" 
+version = "0.6.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" + +[[package]] +name = "regex-syntax" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "436b050e76ed2903236f032a59761c1eb99e1b0aead2c257922771dab1fc8c78" + +[[package]] +name = "remove_dir_all" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3acd125665422973a33ac9d3dd2df85edad0f4ae9b00dafb1a05e43a9f5ef8e7" +dependencies = [ + "winapi", +] + +[[package]] +name = "reqwest" +version = "0.11.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "046cd98826c46c2ac8ddecae268eb5c2e58628688a5fc7a2643704a73faba95b" +dependencies = [ + "base64 0.21.2", + "bytes", + "encoding_rs", + "futures-core", + "futures-util", + "h2", + "http", + "http-body", + "hyper", + "hyper-tls", + "ipnet", + "js-sys", + "log", + "mime", + "mime_guess", + "native-tls", + "once_cell", + "percent-encoding", + "pin-project-lite", + "serde", + "serde_json", + "serde_urlencoded", + "system-configuration", + "tokio", + "tokio-native-tls", + "tokio-util", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "wasm-streams", + "web-sys", + "winreg", +] + +[[package]] +name = "revm" +version = "3.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f293f351c4c203d321744e54ed7eed3d2b6eef4c140228910dde3ac9a5ea8031" +dependencies = [ + "auto_impl", + "revm-interpreter", + "revm-precompile", +] + +[[package]] +name = "revm-interpreter" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a53980a26f9b5a66d13511c35074d4b53631e157850a1d7cf1af4efc2c2b72c9" +dependencies = [ + "derive_more", + "enumn", + "revm-primitives", + "sha3 0.10.8", +] + +[[package]] +name = "revm-precompile" +version = "2.0.3" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "41320af3bd6a65153d38eb1d3638ba89104cc9513c7feedb2d8510e8307dab29" +dependencies = [ + "k256", + "num", + "once_cell", + "revm-primitives", + "ripemd", + "secp256k1", + "sha2 0.10.7", + "sha3 0.10.8", + "substrate-bn", +] + +[[package]] +name = "revm-primitives" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "304d998f466ffef72d76c7f20b05bf08a96801736a6fb1fdef47d49a292618df" +dependencies = [ + "auto_impl", + "bitvec 1.0.1", + "bytes", + "derive_more", + "enumn", + "fixed-hash", + "hashbrown 0.13.2", + "hex", + "hex-literal", + "primitive-types", + "rlp", + "ruint", + "sha3 0.10.8", +] + +[[package]] +name = "rfc6979" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8dd2a808d456c4a54e300a23e9f5a67e122c3024119acbfd73e3bf664491cb2" +dependencies = [ + "hmac", + "subtle", +] + +[[package]] +name = "ripemd" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd124222d17ad93a644ed9d011a40f4fb64aa54275c08cc216524a9ea82fb09f" +dependencies = [ + "digest 0.10.7", +] + +[[package]] +name = "rlp" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb919243f34364b6bd2fc10ef797edbfa75f33c252e7998527479c6d6b47e1ec" +dependencies = [ + "bytes", + "rlp-derive", + "rustc-hex", +] + +[[package]] +name = "rlp-derive" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e33d7b2abe0c340d8797fe2907d3f20d3b5ea5908683618bfe80df7f621f672a" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "ruint" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95294d6e3a6192f3aabf91c38f56505a625aa495533442744185a36d75a790c4" +dependencies = [ + "alloy-rlp", + "ark-ff 0.3.0", + "ark-ff 0.4.2", + "bytes", + "fastrlp", + 
"num-bigint", + "parity-scale-codec", + "primitive-types", + "proptest", + "rand 0.8.5", + "rlp", + "ruint-macro", + "serde", + "valuable", + "zeroize", +] + +[[package]] +name = "ruint-macro" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e666a5496a0b2186dbcd0ff6106e29e093c15591bde62c20d3842007c6978a09" + +[[package]] +name = "rustacuda" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47208516ab5338b592d63560e90eaef405d0ec880347eaf7742d893b0a31e228" +dependencies = [ + "bitflags 1.3.2", + "cuda-driver-sys", + "rustacuda_core", + "rustacuda_derive", +] + +[[package]] +name = "rustacuda_core" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3858b08976dc2f860c5efbbb48cdcb0d4fafca92a6ac0898465af16c0dbe848" + +[[package]] +name = "rustacuda_derive" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43ce8670a1a1d0fc2514a3b846dacdb65646f9bd494b6674cfacbb4ce430bd7e" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "rustc-hex" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e75f6a532d0fd9f7f13144f392b6ad56a32696bfcd9c78f797f16bbb6f072d6" + +[[package]] +name = "rustc_version" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0dfe2087c51c460008730de8b57e6a320782fbfb312e1f4d520e6c6fae155ee" +dependencies = [ + "semver 0.11.0", +] + +[[package]] +name = "rustc_version" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" +dependencies = [ + "semver 1.0.17", +] + +[[package]] +name = "rustfft" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"e17d4f6cbdb180c9f4b2a26bbf01c4e647f1e1dea22fe8eb9db54198b32f9434" +dependencies = [ + "num-complex", + "num-integer", + "num-traits", + "primal-check", + "strength_reduce", + "transpose", + "version_check", +] + +[[package]] +name = "rustix" +version = "0.37.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b96e891d04aa506a6d1f318d2771bcb1c7dfda84e126660ace067c9b474bb2c0" +dependencies = [ + "bitflags 1.3.2", + "errno", + "io-lifetimes", + "libc", + "linux-raw-sys", + "windows-sys 0.48.0", +] + +[[package]] +name = "rustversion" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f3208ce4d8448b3f3e7d168a73f5e0c43a61e32930de3bceeccedb388b6bf06" + +[[package]] +name = "ryu" +version = "1.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041" + +[[package]] +name = "safetensors" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d93279b86b3de76f820a8854dd06cbc33cfa57a417b19c47f6a25280112fb1df" +dependencies = [ + "serde", + "serde_json", +] + +[[package]] +name = "salsa20" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97a22f5af31f73a954c10289c93e8a50cc23d971e80ee446f1f6f7137a088213" +dependencies = [ + "cipher", +] + +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "scale-info" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad560913365790f17cbf12479491169f01b9d46d29cfc7422bf8c64bdc61b731" +dependencies = [ + "cfg-if", + "derive_more", + "parity-scale-codec", + "scale-info-derive", +] + +[[package]] +name = "scale-info-derive" +version = 
"2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19df9bd9ace6cc2fe19387c96ce677e823e07d017ceed253e7bb3d1d1bd9c73b" +dependencies = [ + "proc-macro-crate", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "scan_fmt" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b53b0a5db882a8e2fdaae0a43f7b39e7e9082389e978398bdf223a55b581248" +dependencies = [ + "regex", +] + +[[package]] +name = "schannel" +version = "0.1.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "713cfb06c7059f3588fb8044c0fad1d09e3c01d225e25b9220dbfdcf16dbb1b3" +dependencies = [ + "windows-sys 0.42.0", +] + +[[package]] +name = "scoped-tls" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294" + +[[package]] +name = "scopeguard" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" + +[[package]] +name = "scrypt" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f9e24d2b632954ded8ab2ef9fea0a0c769ea56ea98bddbafbad22caeeadf45d" +dependencies = [ + "hmac", + "pbkdf2 0.11.0", + "salsa20", + "sha2 0.10.7", +] + +[[package]] +name = "sec1" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0aec48e813d6b90b15f0b8948af3c63483992dee44c03e9930b3eebdabe046e" +dependencies = [ + "base16ct", + "der", + "generic-array", + "pkcs8", + "subtle", + "zeroize", +] + +[[package]] +name = "secp256k1" +version = "0.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25996b82292a7a57ed3508f052cfff8640d38d32018784acd714758b43da9c8f" +dependencies = [ + "secp256k1-sys", +] + +[[package]] +name = "secp256k1-sys" +version = "0.8.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "70a129b9e9efbfb223753b9163c4ab3b13cff7fd9c7f010fbac25ab4099fa07e" +dependencies = [ + "cc", +] + +[[package]] +name = "security-framework" +version = "2.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fc758eb7bffce5b308734e9b0c1468893cae9ff70ebf13e7090be8dcbcc83a8" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f51d0c0d83bec45f16480d0ce0058397a69e48fcdc52d1dc8855fb68acbd31a7" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "semver" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f301af10236f6df4160f7c3f04eec6dbc70ace82d23326abad5edee88801c6b6" +dependencies = [ + "semver-parser", +] + +[[package]] +name = "semver" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bebd363326d05ec3e2f532ab7660680f3b02130d780c299bca73469d521bc0ed" +dependencies = [ + "serde", +] + +[[package]] +name = "semver-parser" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0bef5b7f9e0df16536d3961cfb6e84331c065b4066afb39768d0e319411f7" +dependencies = [ + "pest", +] + +[[package]] +name = "send_wrapper" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f638d531eccd6e23b980caf34876660d38e265409d8e99b397ab71eb3612fad0" + +[[package]] +name = "send_wrapper" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd0b0ec5f1c1ca621c432a25813d8d60c88abe6d3e08a3eb9cf37d97a0fe3d73" + +[[package]] +name = "seq-macro" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum 
= "e6b44e8fc93a14e66336d230954dda83d18b4605ccace8fe09bc7514a71ad0bc" + +[[package]] +name = "serde" +version = "1.0.164" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e8c8cf938e98f769bc164923b06dce91cea1751522f46f8466461af04c9027d" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde-wasm-bindgen" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3b4c031cd0d9014307d82b8abf653c0290fbdaeb4c02d00c63cf52f728628bf" +dependencies = [ + "js-sys", + "serde", + "wasm-bindgen", +] + +[[package]] +name = "serde_arrays" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38636132857f68ec3d5f3eb121166d2af33cb55174c4d5ff645db6165cbef0fd" +dependencies = [ + "serde", +] + +[[package]] +name = "serde_cbor" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bef2ebfde456fb76bbcf9f59315333decc4fda0b2b44b420243c11e0f5ec1f5" +dependencies = [ + "half 1.8.2", + "serde", +] + +[[package]] +name = "serde_derive" +version = "1.0.164" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9735b638ccc51c28bf6914d90a2e9725b377144fc612c49a611fddd1b631d68" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.22", +] + +[[package]] +name = "serde_json" +version = "1.0.99" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46266871c240a00b8f503b877622fe33430b3c7d963bdc0f2adc511e54a1eae3" +dependencies = [ + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "serde_spanned" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96426c9936fd7a0124915f9185ea1d20aa9445cc9821142f0a73bc9207a2e186" +dependencies = [ + "serde", +] + +[[package]] +name = "serde_urlencoded" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" +dependencies = [ + "form_urlencoded", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "sha1" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest 0.10.7", +] + +[[package]] +name = "sha2" +version = "0.9.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4d58a1e1bf39749807d89cf2d98ac2dfa0ff1cb3faa38fbb64dd88ac8013d800" +dependencies = [ + "block-buffer 0.9.0", + "cfg-if", + "cpufeatures", + "digest 0.9.0", + "opaque-debug", +] + +[[package]] +name = "sha2" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "479fb9d862239e610720565ca91403019f2f00410f1864c5aa7479b950a76ed8" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest 0.10.7", +] + +[[package]] +name = "sha3" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f81199417d4e5de3f04b1e871023acea7389672c4135918f05aa9cbf2f2fa809" +dependencies = [ + "block-buffer 0.9.0", + "digest 0.9.0", + "keccak", + "opaque-debug", +] + +[[package]] +name = "sha3" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75872d278a8f37ef87fa0ddbda7802605cb18344497949862c0d4dcb291eba60" +dependencies = [ + "digest 0.10.7", + "keccak", +] + +[[package]] +name = "shellexpand" +version = "3.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da03fa3b94cc19e3ebfc88c4229c49d8f08cdbd1228870a45f0ffdf84988e14b" +dependencies = [ + "dirs", +] + +[[package]] +name = "signature" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e1788eed21689f9cf370582dfc467ef36ed9c707f073528ddafa8d83e3b8500" +dependencies = [ + "digest 0.10.7", + "rand_core 0.6.4", +] + 
+[[package]] +name = "siphasher" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7bd3e3206899af3f8b12af284fafc038cc1dc2b41d1b89dd17297221c5d225de" + +[[package]] +name = "slab" +version = "0.4.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6528351c9bc8ab22353f9d776db39a20288e8d6c37ef8cfe3317cf875eecfc2d" +dependencies = [ + "autocfg", +] + +[[package]] +name = "smallvec" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0" + +[[package]] +name = "smol_str" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74212e6bbe9a4352329b2f68ba3130c15a3f26fe88ff22dbdc6cdd58fa85e99c" +dependencies = [ + "serde", +] + +[[package]] +name = "snark-verifier" +version = "0.1.0" +source = "git+https://github.com/zkonduit/snark-verifier?branch=ac/chunked-mv-lookup#a1ac764143960023551e99da000157682ff4d970" +dependencies = [ + "ecc", + "halo2_proofs", + "halo2curves 0.1.0", + "hex", + "itertools 0.10.5", + "lazy_static", + "num-bigint", + "num-integer", + "num-traits", + "poseidon", + "rand 0.8.5", + "revm", + "serde", + "sha3 0.10.8", +] + +[[package]] +name = "socket2" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64a4a911eed85daf18834cfaa86a79b7d266ff93ff5ba14005426219480ed662" +dependencies = [ + "libc", + "winapi", +] + +[[package]] +name = "socket2" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2538b18701741680e0322a2302176d3253a35388e2e62f172f64f4f16605f877" +dependencies = [ + "libc", + "windows-sys 0.48.0", +] + +[[package]] +name = "solang-parser" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a94494913728908efa7a25a2dd2e4f037e714897985c24273c40596638ed909" +dependencies = [ + 
"itertools 0.10.5", + "lalrpop", + "lalrpop-util", + "phf", + "thiserror", + "unicode-xid", +] + +[[package]] +name = "spin" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" + +[[package]] +name = "spki" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d1e996ef02c474957d681f1b05213dfb0abab947b446a62d37770b23500184a" +dependencies = [ + "base64ct", + "der", +] + +[[package]] +name = "spmc" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02a8428da277a8e3a15271d79943e80ccc2ef254e78813a166a08d65e4c3ece5" + +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + +[[package]] +name = "strength_reduce" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82" + +[[package]] +name = "string-interner" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e2531d8525b29b514d25e275a43581320d587b86db302b9a7e464bac579648" +dependencies = [ + "cfg-if", + "hashbrown 0.11.2", + "serde", +] + +[[package]] +name = "string_cache" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f91138e76242f575eb1d3b38b4f1362f10d3a43f47d182a5b359af488a02293b" +dependencies = [ + "new_debug_unreachable", + "once_cell", + "parking_lot", + "phf_shared 0.10.0", + "precomputed-hash", +] + +[[package]] +name = "stringprep" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db3737bde7edce97102e0e2b15365bf7a20bfdb5f60f4f9e8d7004258a51a8da" +dependencies = [ + "unicode-bidi", + "unicode-normalization", +] + 
+[[package]] +name = "strsim" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6446ced80d6c486436db5c078dde11a9f73d42b57fb273121e160b84f63d894c" + +[[package]] +name = "strsim" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" + +[[package]] +name = "strum" +version = "0.24.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "063e6045c0e62079840579a7e47a355ae92f60eb74daaf156fb1e84ba164e63f" +dependencies = [ + "strum_macros", +] + +[[package]] +name = "strum_macros" +version = "0.24.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e385be0d24f186b4ce2f9982191e7101bb737312ad61c1f2f984f34bcf85d59" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "rustversion", + "syn 1.0.109", +] + +[[package]] +name = "substrate-bn" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b5bbfa79abbae15dd642ea8176a21a635ff3c00059961d1ea27ad04e5b441c" +dependencies = [ + "byteorder", + "crunchy", + "lazy_static", + "rand 0.8.5", + "rustc-hex", +] + +[[package]] +name = "subtle" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "syn" +version = "2.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2efbeae7acf4eabd6bcdcbd11c92f45231ddda7539edc7806bd1a04a03b24616" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "system-configuration" +version = 
"0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "system-configuration-sys", +] + +[[package]] +name = "system-configuration-sys" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "tabbycat" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c45590f0f859197b4545be1b17b2bc3cc7bb075f7d1cc0ea1dc6521c0bf256a3" +dependencies = [ + "anyhow", + "derive_builder", + "regex", +] + +[[package]] +name = "tabled" +version = "0.12.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ce69a5028cd9576063ec1f48edb2c75339fd835e6094ef3e05b3a079bf594a6" +dependencies = [ + "ansi-str", + "ansitok", + "papergrid", + "tabled_derive", + "unicode-width", +] + +[[package]] +name = "tabled_derive" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99f688a08b54f4f02f0a3c382aefdb7884d3d69609f785bd253dc033243e3fe4" +dependencies = [ + "heck", + "proc-macro-error", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "tap" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" + +[[package]] +name = "tar" +version = "0.4.38" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b55807c0344e1e6c04d7c965f5289c39a8d94ae23ed5c0b57aabac549f871c6" +dependencies = [ + "filetime", + "libc", + "xattr", +] + +[[package]] +name = "target-lexicon" +version = "0.12.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"1b1c7f239eb94671427157bd93b3694320f3668d4e1eff08c7285366fd777fac" + +[[package]] +name = "tch" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ed5dddab3812892bf5fb567136e372ea49f31672931e21cec967ca68aec03da" +dependencies = [ + "half 2.2.1", + "lazy_static", + "libc", + "ndarray", + "rand 0.8.5", + "safetensors", + "thiserror", + "torch-sys", + "zip", +] + +[[package]] +name = "tempdir" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15f2b5fb00ccdf689e0149d1b1b3c03fead81c2b37735d812fa8bddbbf41b6d8" +dependencies = [ + "rand 0.4.6", + "remove_dir_all", +] + +[[package]] +name = "tempfile" +version = "3.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31c0432476357e58790aaa47a8efb0c5138f137343f3b5f23bd36a27e3b0a6d6" +dependencies = [ + "autocfg", + "cfg-if", + "fastrand", + "redox_syscall 0.3.5", + "rustix", + "windows-sys 0.48.0", +] + +[[package]] +name = "term" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c59df8ac95d96ff9bede18eb7300b0fda5e5d8d90960e76f8e14ae765eedbf1f" +dependencies = [ + "dirs-next", + "rustversion", + "winapi", +] + +[[package]] +name = "termcolor" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6093bad37da69aab9d123a8091e4be0aa4a03e4d601ec641c327398315f62b64" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "test-case" +version = "2.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21d6cf5a7dffb3f9dceec8e6b8ca528d9bd71d36c9f074defb548ce161f598c0" +dependencies = [ + "test-case-macros", +] + +[[package]] +name = "test-case-macros" +version = "2.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e45b7bf6e19353ddd832745c8fcf77a17a93171df7151187f26623f2b75b5b26" +dependencies = [ + "cfg-if", + "proc-macro-error", + "proc-macro2", + 
"quote", + "syn 1.0.109", +] + +[[package]] +name = "textwrap" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060" +dependencies = [ + "unicode-width", +] + +[[package]] +name = "thiserror" +version = "1.0.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "978c9a314bd8dc99be594bc3c175faaa9794be04a5a5e153caba6915336cebac" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.22", +] + +[[package]] +name = "time" +version = "0.3.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59e399c068f43a5d116fedaf73b203fa4f9c519f17e2b34f63221d3792f81446" +dependencies = [ + "itoa", + "serde", + "time-core", + "time-macros", +] + +[[package]] +name = "time-core" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7300fbefb4dadc1af235a9cef3737cea692a9d97e1b9cbcd4ebdae6f8868e6fb" + +[[package]] +name = "time-macros" +version = "0.2.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96ba15a897f3c86766b757e5ac7221554c6750054d74d5b28844fce5fb36a6c4" +dependencies = [ + "time-core", +] + +[[package]] +name = "tiny-keccak" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c9d3793400a45f954c52e73d068316d76b6f4e36977e3fcebb13a2721e80237" +dependencies = [ + "crunchy", +] + +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + +[[package]] +name = 
"tinyvec" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87cc5ceb3875bb20c2890005a4e226a4651264a5c75edb2421b52861a0a0cb50" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + +[[package]] +name = "tokio" +version = "1.28.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94d7b1cfd2aa4011f2de74c2c4c63665e27a71006b0a192dcd2710272e73dfa2" +dependencies = [ + "autocfg", + "bytes", + "libc", + "mio", + "num_cpus", + "pin-project-lite", + "socket2 0.4.9", + "tokio-macros", + "windows-sys 0.48.0", +] + +[[package]] +name = "tokio-macros" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.22", +] + +[[package]] +name = "tokio-native-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +dependencies = [ + "native-tls", + "tokio", +] + +[[package]] +name = "tokio-postgres" +version = "0.7.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e89f6234aa8fd43779746012fcf53603cdb91fdd8399aa0de868c2d56b6dde1" +dependencies = [ + "async-trait", + "byteorder", + "bytes", + "fallible-iterator", + "futures-channel", + "futures-util", + "log", + "parking_lot", + "percent-encoding", + "phf", + "pin-project-lite", + "postgres-protocol", + "postgres-types", + "socket2 0.5.3", + "tokio", + "tokio-util", +] + +[[package]] +name = "tokio-util" +version = "0.7.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d68074620f57a0b21594d9735eb2e98ab38b17f80d3fcb189fca266771ca60d" 
+dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", + "tracing", +] + +[[package]] +name = "toml" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ebafdf5ad1220cb59e7d17cf4d2c72015297b75b19a10472f99b89225089240" +dependencies = [ + "serde", + "serde_spanned", + "toml_datetime", + "toml_edit", +] + +[[package]] +name = "toml_datetime" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7cda73e2f1397b1262d6dfdcef8aafae14d1de7748d66822d3bfeeb6d03e5e4b" +dependencies = [ + "serde", +] + +[[package]] +name = "toml_edit" +version = "0.19.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "266f016b7f039eec8a1a80dfe6156b633d208b9fccca5e4db1d6775b0c4e34a7" +dependencies = [ + "indexmap 2.0.0", + "serde", + "serde_spanned", + "toml_datetime", + "winnow", +] + +[[package]] +name = "torch-sys" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "803446f89fb877a117503dbfb8375b6a29fa8b0e0f44810fac3863c798ecef22" +dependencies = [ + "anyhow", + "cc", + "libc", + "zip", +] + +[[package]] +name = "tower-service" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" + +[[package]] +name = "tracing" +version = "0.1.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8" +dependencies = [ + "cfg-if", + "pin-project-lite", + "tracing-attributes", + "tracing-core", +] + +[[package]] +name = "tracing-attributes" +version = "0.1.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f4f31f56159e98206da9efd823404b79b6ef3143b4a7ab76e67b1751b25a4ab" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.22", +] + +[[package]] +name = 
"tracing-core" +version = "0.1.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0955b8137a1df6f1a2e9a37d8a6656291ff0297c1a97c24e0d8425fe2312f79a" +dependencies = [ + "once_cell", +] + +[[package]] +name = "tracing-futures" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97d095ae15e245a057c8e8451bab9b3ee1e1f68e9ba2b4fbc18d0ac5237835f2" +dependencies = [ + "pin-project", + "tracing", +] + +[[package]] +name = "tract-core" +version = "0.20.23-pre" +source = "git+https://github.com/sonos/tract/?rev=ee98004a2d8d7851da7b9fce954b2a7a7181eccb#ee98004a2d8d7851da7b9fce954b2a7a7181eccb" +dependencies = [ + "anyhow", + "bit-set", + "derive-new", + "downcast-rs", + "dyn-clone", + "lazy_static", + "log", + "maplit", + "ndarray", + "num-complex", + "num-integer", + "num-traits", + "paste", + "rustfft", + "smallvec", + "tract-data", + "tract-linalg", +] + +[[package]] +name = "tract-data" +version = "0.20.23-pre" +source = "git+https://github.com/sonos/tract/?rev=ee98004a2d8d7851da7b9fce954b2a7a7181eccb#ee98004a2d8d7851da7b9fce954b2a7a7181eccb" +dependencies = [ + "anyhow", + "half 2.2.1", + "itertools 0.10.5", + "lazy_static", + "maplit", + "ndarray", + "nom", + "num-integer", + "num-traits", + "scan_fmt", + "smallvec", + "string-interner", +] + +[[package]] +name = "tract-hir" +version = "0.20.23-pre" +source = "git+https://github.com/sonos/tract/?rev=ee98004a2d8d7851da7b9fce954b2a7a7181eccb#ee98004a2d8d7851da7b9fce954b2a7a7181eccb" +dependencies = [ + "derive-new", + "log", + "tract-core", +] + +[[package]] +name = "tract-linalg" +version = "0.20.23-pre" +source = "git+https://github.com/sonos/tract/?rev=ee98004a2d8d7851da7b9fce954b2a7a7181eccb#ee98004a2d8d7851da7b9fce954b2a7a7181eccb" +dependencies = [ + "cc", + "derive-new", + "downcast-rs", + "dyn-clone", + "half 2.2.1", + "lazy_static", + "liquid", + "liquid-core", + "log", + "num-traits", + "paste", + "scan_fmt", + "smallvec", + "time", + 
"tract-data", + "unicode-normalization", + "walkdir", +] + +[[package]] +name = "tract-nnef" +version = "0.20.23-pre" +source = "git+https://github.com/sonos/tract/?rev=ee98004a2d8d7851da7b9fce954b2a7a7181eccb#ee98004a2d8d7851da7b9fce954b2a7a7181eccb" +dependencies = [ + "byteorder", + "flate2", + "log", + "nom", + "tar", + "tract-core", + "walkdir", +] + +[[package]] +name = "tract-onnx" +version = "0.20.23-pre" +source = "git+https://github.com/sonos/tract/?rev=ee98004a2d8d7851da7b9fce954b2a7a7181eccb#ee98004a2d8d7851da7b9fce954b2a7a7181eccb" +dependencies = [ + "bytes", + "derive-new", + "log", + "memmap2", + "num-integer", + "prost", + "smallvec", + "tract-hir", + "tract-nnef", + "tract-onnx-opl", +] + +[[package]] +name = "tract-onnx-opl" +version = "0.20.23-pre" +source = "git+https://github.com/sonos/tract/?rev=ee98004a2d8d7851da7b9fce954b2a7a7181eccb#ee98004a2d8d7851da7b9fce954b2a7a7181eccb" +dependencies = [ + "getrandom", + "log", + "rand 0.8.5", + "rand_distr", + "rustfft", + "tract-nnef", +] + +[[package]] +name = "transpose" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6522d49d03727ffb138ae4cbc1283d3774f0d10aa7f9bf52e6784c45daf9b23" +dependencies = [ + "num-integer", + "strength_reduce", +] + +[[package]] +name = "try-lock" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed" + +[[package]] +name = "typenum" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba" + +[[package]] +name = "ucd-trie" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e79c4d996edb816c91e4308506774452e55e95c3c9de07b6729e17e15a5ef81" + +[[package]] +name = "uint" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"76f64bba2c53b04fcab63c01a7d7427eadc821e3bc48c34dc9ba29c501164b52" +dependencies = [ + "byteorder", + "crunchy", + "hex", + "static_assertions", +] + +[[package]] +name = "unarray" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eaea85b334db583fe3274d12b4cd1880032beab409c0d774be044d4480ab9a94" + +[[package]] +name = "unicase" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7d2d4dafb69621809a81864c9c1b864479e1235c0dd4e199924b9742439ed89" +dependencies = [ + "version_check", +] + +[[package]] +name = "unicode-bidi" +version = "0.3.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92888ba5573ff080736b3648696b70cafad7d250551175acbaa4e0385b3e1460" + +[[package]] +name = "unicode-ident" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b15811caf2415fb889178633e7724bad2509101cde276048e013b9def5e51fa0" + +[[package]] +name = "unicode-normalization" +version = "0.1.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c5713f0fc4b5db668a2ac63cdb7bb4469d8c9fed047b1d0292cc7b0ce2ba921" +dependencies = [ + "tinyvec", +] + +[[package]] +name = "unicode-segmentation" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1dd624098567895118886609431a7c3b8f516e41d30e0643f03d94592a147e36" + +[[package]] +name = "unicode-width" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0edd1e5b14653f783770bce4a4dabb4a5108a5370a5f5d8cfe8710c361f6c8b" + +[[package]] +name = "unicode-xid" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c" + +[[package]] +name = "unindent" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"e1766d682d402817b5ac4490b3c3002d91dfa0d22812f341609f97b08757359c" + +[[package]] +name = "unzip-n" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2e7e85a0596447f0f2ac090e16bc4c516c6fe91771fb0c0ccf7fa3dae896b9c" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "url" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50bff7831e19200a85b17131d085c25d7811bc4e186efdaf54bbd132994a88cb" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", +] + +[[package]] +name = "utf8parse" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" + +[[package]] +name = "uuid" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc5cf98d8186244414c848017f0e2676b3fcb46807f6668a97dfe67359a3c4b7" +dependencies = [ + "getrandom", + "serde", +] + +[[package]] +name = "valuable" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" + +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + +[[package]] +name = "version_check" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" + +[[package]] +name = "vte" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6cbce692ab4ca2f1f3047fcf732430249c0e971bfdd2b234cf2c47ad93af5983" +dependencies = [ + "arrayvec 0.5.2", + "utf8parse", + "vte_generate_state_changes", +] + +[[package]] +name = "vte_generate_state_changes" +version = "0.1.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "d257817081c7dffcdbab24b9e62d2def62e2ff7d00b1c20062551e6cccc145ff" +dependencies = [ + "proc-macro2", + "quote", +] + +[[package]] +name = "walkdir" +version = "2.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36df944cda56c7d8d8b7496af378e6b16de9284591917d307c9b4d313c44e698" +dependencies = [ + "same-file", + "winapi-util", +] + +[[package]] +name = "want" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" +dependencies = [ + "try-lock", +] + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "wasm-bindgen" +version = "0.2.87" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7706a72ab36d8cb1f80ffbf0e071533974a60d0a308d01a5d0375bf60499a342" +dependencies = [ + "cfg-if", + "serde", + "serde_json", + "wasm-bindgen-macro", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.87" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ef2b6d3c510e9625e5fe6f509ab07d66a760f0885d858736483c32ed7809abd" +dependencies = [ + "bumpalo", + "log", + "once_cell", + "proc-macro2", + "quote", + "syn 2.0.22", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-console-logger" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7530a275e7faf7b5b83aabdf78244fb8d9a68a2ec4b26935a05ecc0c9b0185ed" +dependencies = [ + "log", + "wasm-bindgen", +] + +[[package]] +name = "wasm-bindgen-futures" +version = "0.4.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c02dbc21516f9f1f04f187958890d7e6026df8d16540b7ad9492bc34a67cea03" +dependencies = [ 
+ "cfg-if", + "js-sys", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.87" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dee495e55982a3bd48105a7b947fd2a9b4a8ae3010041b9e0faab3f9cd028f1d" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.87" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54681b18a46765f095758388f2d0cf16eb8d4169b639ab575a8f5693af210c7b" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.22", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-rayon" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df87c67450805c305d3ae44a3ac537b0253d029153c25afc3ecd2edc36ccafb1" +dependencies = [ + "js-sys", + "rayon", + "spmc", + "wasm-bindgen", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.87" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca6ad05a4870b2bf5fe995117d3728437bd27d7cd5f06f13c17443ef369775a1" + +[[package]] +name = "wasm-bindgen-test" +version = "0.3.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e6e302a7ea94f83a6d09e78e7dc7d9ca7b186bc2829c24a22d0753efd680671" +dependencies = [ + "console_error_panic_hook", + "js-sys", + "scoped-tls", + "wasm-bindgen", + "wasm-bindgen-futures", + "wasm-bindgen-test-macro", +] + +[[package]] +name = "wasm-bindgen-test-macro" +version = "0.3.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecb993dd8c836930ed130e020e77d9b2e65dd0fbab1b67c790b0f5d80b11a575" +dependencies = [ + "proc-macro2", + "quote", +] + +[[package]] +name = "wasm-streams" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4609d447824375f43e1ffbc051b50ad8f4b3ae8219680c94452ea05eb240ac7" +dependencies = [ + 
"futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + +[[package]] +name = "web-sys" +version = "0.3.64" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b85cbef8c220a6abc02aefd892dfc0fc23afb1c6a426316ec33253a3877249b" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-util" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70ec6ce85bb158151cae5e5c87f95a8e97d2c0c4b001223f33a334e3ce5de178" +dependencies = [ + "winapi", +] + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + +[[package]] +name = "windows-sys" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7" +dependencies = [ + "windows_aarch64_gnullvm 0.42.2", + "windows_aarch64_msvc 0.42.2", + "windows_i686_gnu 0.42.2", + "windows_i686_msvc 0.42.2", + "windows_x86_64_gnu 0.42.2", + "windows_x86_64_gnullvm 0.42.2", + "windows_x86_64_msvc 0.42.2", +] + +[[package]] +name = "windows-sys" +version = "0.45.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" +dependencies = [ + "windows-targets 0.42.2", +] + 
+[[package]] +name = "windows-sys" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +dependencies = [ + "windows-targets 0.48.0", +] + +[[package]] +name = "windows-targets" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" +dependencies = [ + "windows_aarch64_gnullvm 0.42.2", + "windows_aarch64_msvc 0.42.2", + "windows_i686_gnu 0.42.2", + "windows_i686_msvc 0.42.2", + "windows_x86_64_gnu 0.42.2", + "windows_x86_64_gnullvm 0.42.2", + "windows_x86_64_msvc 0.42.2", +] + +[[package]] +name = "windows-targets" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b1eb6f0cd7c80c79759c929114ef071b87354ce476d9d94271031c0497adfd5" +dependencies = [ + "windows_aarch64_gnullvm 0.48.0", + "windows_aarch64_msvc 0.48.0", + "windows_i686_gnu 0.48.0", + "windows_i686_msvc 0.48.0", + "windows_x86_64_gnu 0.48.0", + "windows_x86_64_gnullvm 0.48.0", + "windows_x86_64_msvc 0.48.0", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91ae572e1b79dba883e0d315474df7305d12f569b400fcf90581b06062f7e1bc" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"b2ef27e0d7bdfcfc7b868b317c1d32c641a6fe4629c171b8928c7b08d98d7cf3" + +[[package]] +name = "windows_i686_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" + +[[package]] +name = "windows_i686_gnu" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "622a1962a7db830d6fd0a69683c80a18fda201879f0f447f065a3b7467daa241" + +[[package]] +name = "windows_i686_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" + +[[package]] +name = "windows_i686_msvc" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4542c6e364ce21bf45d69fdd2a8e455fa38d316158cfd43b3ac1c5b1b19f8e00" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca2b8a661f7628cbd23440e50b05d705db3686f894fc9580820623656af974b1" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7896dbc1f41e08872e9d5e8f8baa8fdd2677f29468c4e156210174edc7f7b953" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" + +[[package]] +name = "windows_x86_64_msvc" 
+version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a" + +[[package]] +name = "winnow" +version = "0.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca0ace3845f0d96209f0375e6d367e3eb87eb65d27d445bdc9f1843a26f39448" +dependencies = [ + "memchr", +] + +[[package]] +name = "winreg" +version = "0.50.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "524e57b2c537c0f9b1e69f1965311ec12182b4122e45035b1508cd24d2adadb1" +dependencies = [ + "cfg-if", + "windows-sys 0.48.0", +] + +[[package]] +name = "ws_stream_wasm" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7999f5f4217fe3818726b66257a4475f71e74ffd190776ad053fa159e50737f5" +dependencies = [ + "async_io_stream", + "futures", + "js-sys", + "log", + "pharos", + "rustc_version 0.4.0", + "send_wrapper 0.6.0", + "thiserror", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + +[[package]] +name = "wyz" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed" +dependencies = [ + "tap", +] + +[[package]] +name = "xattr" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d1526bbe5aaeb5eb06885f4d987bcdfa5e23187055de9b83fe00156a821fabc" +dependencies = [ + "libc", +] + +[[package]] +name = "yansi" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09041cd90cf85f7f8b2df60c646f853b7f535ce68f85244eb6731cf89fa498ec" + +[[package]] +name = "zeroize" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a0956f1ba7c7909bfb66c2e9e4124ab6f6482560f6628b5aaeba39207c9aad9" +dependencies = [ + "zeroize_derive", +] + +[[package]] +name = "zeroize_derive" +version = 
"1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.22", +] + +[[package]] +name = "zip" +version = "0.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "760394e246e4c28189f19d488c058bf16f564016aefac5d32bb1f3b51d5e9261" +dependencies = [ + "aes", + "byteorder", + "bzip2", + "constant_time_eq 0.1.5", + "crc32fast", + "crossbeam-utils", + "flate2", + "hmac", + "pbkdf2 0.11.0", + "sha1", + "time", + "zstd", +] + +[[package]] +name = "zstd" +version = "0.11.2+zstd.1.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20cc960326ece64f010d2d2107537f26dc589a6573a316bd5b1dba685fa5fde4" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "5.0.2+zstd.1.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d2a5585e04f9eea4b2a3d1eca508c4dee9592a89ef6f450c11719da0726f4db" +dependencies = [ + "libc", + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.9+zstd.1.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e16efa8a874a0481a574084d34cc26fdb3b99627480f785888deb6386506656" +dependencies = [ + "cc", + "pkg-config", +] diff --git a/mnist_ezkl/Cargo.toml b/mnist_ezkl/Cargo.toml new file mode 100644 index 0000000..950850f --- /dev/null +++ b/mnist_ezkl/Cargo.toml @@ -0,0 +1,168 @@ +[package] +name = "ezkl" +version = "0.0.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[lib] +# Name to be imported within python +# Example: import ezkl +name = "ezkl" +crate-type = ["cdylib", "rlib"] + + +[dependencies] +halo2_gadgets = { git = "https://github.com/zkonduit/halo2", branch= "ac/lookup-modularity" } +halo2_proofs = { git = "https://github.com/zkonduit/halo2", branch= 
"ac/lookup-modularity" } +halo2curves = { version = "0.1.0" } +rand = { version = "0.8", default_features = false } +itertools = { version = "0.10.3", default_features = false } +clap = { version = "4.3.3", features = ["derive"]} +serde = { version = "1.0.126", features = ["derive"], optional = true } +serde_json = { version = "1.0.97", default_features = false, features = ["float_roundtrip", "raw_value"], optional = true } +log = { version = "0.4.17", default_features = false, optional = true } +thiserror = { version = "1.0.38", default_features = false } +hex = { version = "0.4.3", default_features = false } +halo2_wrong_ecc = { git = "https://github.com/zkonduit/halo2wrong", branch = "ac/chunked-mv-lookup", package = "ecc" } +snark-verifier = { git = "https://github.com/zkonduit/snark-verifier", branch = "ac/chunked-mv-lookup", features=["derive_serde"]} +halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", branch= "ac/lookup-modularity" } +rayon = { version = "1.7.0", default_features = false } +bincode = { version = "1.3.3", default_features = false } +ark-std = { version = "^0.3.0", default-features = false } +unzip-n = "0.1.2" +num = "0.4.1" +tch = "0.14.0" +halo2 = "0.0.0" + +# evm related deps +[target.'cfg(not(target_arch = "wasm32"))'.dependencies] +ethers = { version = "2.0.7", default_features = false, features = ["ethers-solc"] } +indicatif = {version = "0.17.5", features = ["rayon"]} +gag = { version = "1.0.0", default_features = false} +instant = { version = "0.1" } +reqwest = { version = "0.11.14", default-features = false, features = ["default-tls", "multipart", "stream"] } +openssl = { version = "0.10.55", features = ["vendored"] } +postgres = "0.19.5" +pg_bigdecimal = "0.1.5" +lazy_static = "1.4.0" +colored_json = { version = "3.0.1", default_features = false, optional = true} +plotters = { version = "0.3.0", default_features = false, optional = true } +regex = { version = "1", default_features = false 
} +tokio = { version = "1.26.0", default_features = false, features = ["macros", "rt"] } +tokio-util = { version = "0.7.9", features = ["codec"] } +pyo3 = { version = "0.18.3", features = ["extension-module", "abi3-py37", "macros"], default_features = false, optional = true } +pyo3-asyncio = { version = "0.18.0", features = ["attributes", "tokio-runtime"], default_features = false, optional = true } +pyo3-log = { version = "0.8.1", default_features = false, optional = true } +tract-onnx = { git = "https://github.com/sonos/tract/", rev= "ee98004a2d8d7851da7b9fce954b2a7a7181eccb", default_features = false, optional = true } +tabled = { version = "0.12.0", optional = true } + + +[target.'cfg(not(all(target_arch = "wasm32", target_os = "unknown")))'.dependencies] +colored = { version = "2.0.0", default_features = false, optional = true} +env_logger = { version = "0.10.0", default_features = false, optional = true} + + +[target.'cfg(target_arch = "wasm32")'.dependencies] +getrandom = { version = "0.2.8", features = ["js"] } +instant = { version = "0.1", features = [ "wasm-bindgen", "inaccurate" ] } + +[target.'cfg(all(target_arch = "wasm32", target_os = "unknown"))'.dependencies] +wasm-bindgen-rayon = { version = "1.0", optional=true } +wasm-bindgen-test = "0.3.34" +serde-wasm-bindgen = "0.4" +wasm-bindgen = { version = "0.2.81", features = ["serde-serialize"]} +console_error_panic_hook = "0.1.7" +wasm-bindgen-console-logger = "0.1.1" + + +[dev-dependencies] +criterion = {version = "0.3", features = ["html_reports"]} +tempfile = "3.3.0" +lazy_static = "1.4.0" +mnist = "0.5" +seq-macro = "0.3.1" +test-case = "2.2.2" +tempdir = "0.3.7" +shellexpand = "3.1.0" +benchy = "0.1.1" + +[target.wasm32-unknown-unknown] +runner = 'wasm-bindgen-test-runner' + + +# [[bench]] +# name = "accum_dot" +# harness = false + + +# [[bench]] +# name = "accum_sum" +# harness = false + +# [[bench]] +# name = "pairwise_add" +# harness = false + + +# [[bench]] +# name = "pairwise_pow" +# harness = 
false + +# [[bench]] +# name = "poseidon" +# harness = false + +# [[bench]] +# name = "elgamal" +# harness = false + + +# [[bench]] +# name = "accum_einsum_matmul" +# harness = false + + +# [[bench]] +# name = "accum_conv" +# harness = false + + +# [[bench]] +# name = "accum_sumpool" +# harness = false + + +# [[bench]] +# name = "relu" +# harness = false + +# [[bench]] +# name = "accum_matmul_relu" +# harness = false + + +# [[bench]] +# name = "accum_matmul_relu_overflow" +# harness = false + +[[bench]] +name = "bench" +harness = false + +[[bin]] +name = "ezkl" +test = false +bench = false +required-features = ["ezkl"] + +[features] +web = ["wasm-bindgen-rayon"] +default = ["ezkl", "mv-lookup"] +render = ["halo2_proofs/dev-graph", "plotters"] +onnx = ["dep:tract-onnx"] +python-bindings = ["pyo3", "pyo3-log", "pyo3-asyncio"] +ezkl = ["onnx", "serde", "serde_json", "log", "colored", "env_logger", "tabled/color", "colored_json", "halo2_proofs/circuit-params"] +mv-lookup = ["halo2_proofs/mv-lookup", "snark-verifier/mv-lookup", "halo2_solidity_verifier/mv-lookup"] +det-prove = [] +icicle = ["halo2_proofs/icicle_gpu"] diff --git a/mnist_ezkl/Pipfile b/mnist_ezkl/Pipfile new file mode 100644 index 0000000..0757494 --- /dev/null +++ b/mnist_ezkl/Pipfile @@ -0,0 +1,11 @@ +[[source]] +url = "https://pypi.org/simple" +verify_ssl = true +name = "pypi" + +[packages] + +[dev-packages] + +[requires] +python_version = "3.11" diff --git a/mnist_ezkl/abis/DataAttestation.json b/mnist_ezkl/abis/DataAttestation.json new file mode 100644 index 0000000..8a90dcd --- /dev/null +++ b/mnist_ezkl/abis/DataAttestation.json @@ -0,0 +1,167 @@ +[ + { + "inputs": [ + { + "internalType": "address[]", + "name": "_contractAddresses", + "type": "address[]" + }, + { + "internalType": "bytes[][]", + "name": "_callData", + "type": "bytes[][]" + }, + { + "internalType": "uint256[][]", + "name": "_decimals", + "type": "uint256[][]" + }, + { + "internalType": "uint256[]", + "name": "_scales", + "type": 
"uint256[]" + }, + { + "internalType": "uint8", + "name": "_instanceOffset", + "type": "uint8" + }, + { + "internalType": "address", + "name": "_admin", + "type": "address" + } + ], + "stateMutability": "nonpayable", + "type": "constructor" + }, + { + "inputs": [ + { + "internalType": "uint256", + "name": "", + "type": "uint256" + } + ], + "name": "accountCalls", + "outputs": [ + { + "internalType": "address", + "name": "contractAddress", + "type": "address" + }, + { + "internalType": "uint256", + "name": "callCount", + "type": "uint256" + } + ], + "stateMutability": "view", + "type": "function" + }, + { + "inputs": [], + "name": "admin", + "outputs": [ + { + "internalType": "address", + "name": "", + "type": "address" + } + ], + "stateMutability": "view", + "type": "function" + }, + { + "inputs": [], + "name": "instanceOffset", + "outputs": [ + { + "internalType": "uint8", + "name": "", + "type": "uint8" + } + ], + "stateMutability": "view", + "type": "function" + }, + { + "inputs": [ + { + "internalType": "uint256", + "name": "", + "type": "uint256" + } + ], + "name": "scales", + "outputs": [ + { + "internalType": "uint256", + "name": "", + "type": "uint256" + } + ], + "stateMutability": "view", + "type": "function" + }, + { + "inputs": [ + { + "internalType": "address[]", + "name": "_contractAddresses", + "type": "address[]" + }, + { + "internalType": "bytes[][]", + "name": "_callData", + "type": "bytes[][]" + }, + { + "internalType": "uint256[][]", + "name": "_decimals", + "type": "uint256[][]" + } + ], + "name": "updateAccountCalls", + "outputs": [], + "stateMutability": "nonpayable", + "type": "function" + }, + { + "inputs": [ + { + "internalType": "address", + "name": "_admin", + "type": "address" + } + ], + "name": "updateAdmin", + "outputs": [], + "stateMutability": "nonpayable", + "type": "function" + }, + { + "inputs": [ + { + "internalType": "address", + "name": "verifier", + "type": "address" + }, + { + "internalType": "bytes", + "name": "encoded", + 
"type": "bytes" + } + ], + "name": "verifyWithDataAttestation", + "outputs": [ + { + "internalType": "bool", + "name": "", + "type": "bool" + } + ], + "stateMutability": "view", + "type": "function" + } +] \ No newline at end of file diff --git a/mnist_ezkl/abis/QuantizeData.json b/mnist_ezkl/abis/QuantizeData.json new file mode 100644 index 0000000..25e58a4 --- /dev/null +++ b/mnist_ezkl/abis/QuantizeData.json @@ -0,0 +1,50 @@ +[ + { + "inputs": [ + { + "internalType": "bytes[]", + "name": "data", + "type": "bytes[]" + }, + { + "internalType": "uint256[]", + "name": "decimals", + "type": "uint256[]" + }, + { + "internalType": "uint256[]", + "name": "scales", + "type": "uint256[]" + } + ], + "name": "quantize_data", + "outputs": [ + { + "internalType": "int128[]", + "name": "quantized_data", + "type": "int128[]" + } + ], + "stateMutability": "pure", + "type": "function" + }, + { + "inputs": [ + { + "internalType": "int128[]", + "name": "quantized_data", + "type": "int128[]" + } + ], + "name": "to_field_element", + "outputs": [ + { + "internalType": "uint256[]", + "name": "output", + "type": "uint256[]" + } + ], + "stateMutability": "pure", + "type": "function" + } +] \ No newline at end of file diff --git a/mnist_ezkl/abis/TestReads.json b/mnist_ezkl/abis/TestReads.json new file mode 100644 index 0000000..b31342d --- /dev/null +++ b/mnist_ezkl/abis/TestReads.json @@ -0,0 +1,32 @@ +[ + { + "inputs": [ + { + "internalType": "int256[]", + "name": "_numbers", + "type": "int256[]" + } + ], + "stateMutability": "nonpayable", + "type": "constructor" + }, + { + "inputs": [ + { + "internalType": "uint256", + "name": "", + "type": "uint256" + } + ], + "name": "arr", + "outputs": [ + { + "internalType": "int256", + "name": "", + "type": "int256" + } + ], + "stateMutability": "view", + "type": "function" + } +] \ No newline at end of file diff --git a/mnist_ezkl/benches/bench.rs b/mnist_ezkl/benches/bench.rs new file mode 100644 index 0000000..18db864 --- /dev/null +++ 
b/mnist_ezkl/benches/bench.rs @@ -0,0 +1,348 @@ +use ezkl::pfsys; +use ezkl::pfsys::TranscriptType; +use ezkl::pfsys::create_proof_circuit_kzg; +use ezkl::pfsys::srs::gen_srs; +use mnist::{Mnist, MnistBuilder}; +use std::{collections::HashMap}; +use tch::{Tensor, Kind, Device, vision, Scalar}; +use benchy::BenchmarkRun; +use ezkl::tensor::ValTensor; +use halo2curves::bn256::Fr; +use halo2_proofs::circuit::Value; +use std::marker::PhantomData; +use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme; +use halo2_wrong_ecc::halo2::circuit::{Chip, Layouter, SimpleFloorPlanner}; +use halo2_wrong_ecc::halo2::plonk::{Advice, Circuit, Column, ConstraintSystem, Error, Expression, Selector}; +use halo2_wrong_ecc::halo2::poly::Rotation; +use ezkl::circuit::CheckMode; +use ezkl::pfsys::create_keys; +use halo2curves::bn256::Bn256; +const IMAGE_SIZE: usize = 28 * 28; + + + +struct MNISTModel { + l1: Linear, + l2: Linear, +} + +impl MNISTModel { + fn new (mem: &mut Memory) -> MNISTModel { + let l1 = Linear::new(mem, 784, 128); + let l2 = Linear::new(mem, 128, 10); + Self { + l1: l1, + l2: l2, + } + } +} + +impl Compute for MNISTModel { + fn forward (&self, mem: &Memory, input: &Tensor) -> Tensor { + let mut o = self.l1.forward(mem, input); + o = o.relu(); + o = self.l2.forward(mem, &o); + o + } +} + +fn main() { + let (x, y) = load_mnist(); + + let mut m = Memory::new(); + let mnist_model = MNISTModel::new(&mut m); + train(&mut m, &x, &y, &mnist_model, 100, 128, cross_entropy, 0.3); + let out = mnist_model.forward(&m, &x); + println!("Training Accuracy: {}", accuracy(&y, &out)); +} + +trait Compute { + fn forward (&self, mem: &Memory, input: &Tensor) -> Tensor; +} + +struct Linear { + params: HashMap, +} + +impl Linear { + fn new (mem: &mut Memory, ninputs: i64, noutputs: i64) -> Self { + let mut p = HashMap::new(); + p.insert("W".to_string(), mem.new_push(&[ninputs,noutputs], true)); + p.insert("b".to_string(), mem.new_push(&[1, noutputs], true)); + + Self { + params: p, + 
} + } +} + +impl Compute for Linear { + fn forward (&self, mem: &Memory, input: &Tensor) -> Tensor { + let w = mem.get(self.params.get(&"W".to_string()).unwrap()); + let b = mem.get(self.params.get(&"b".to_string()).unwrap()); + input.matmul(w) + b + } +} + +fn mse(target: &Tensor, pred: &Tensor) -> Tensor { + (target - pred).square().mean(Kind::Float) +} + +fn cross_entropy (target: &Tensor, pred: &Tensor) -> Tensor { + let loss = pred.log_softmax(-1, Kind::Float).nll_loss(target); + loss +} + +struct Memory { + size: usize, + values: Vec, +} + +impl Memory { + + fn new() -> Self { + let v = Vec::new(); + Self {size: 0, + values: v} + } + + fn push (&mut self, value: Tensor) -> usize { + self.values.push(value); + self.size += 1; + self.size-1 + } + + fn new_push (&mut self, size: &[i64], requires_grad: bool) -> usize { + let t = Tensor::randn(size, (Kind::Float, Device::Cpu)).requires_grad_(requires_grad); + self.push(t) + } + + fn get (&self, addr: &usize) -> &Tensor { + &self.values[*addr] + } + + fn apply_grads_sgd(&mut self, learning_rate: f32) { + let mut g = Tensor::new(); + self.values + .iter_mut() + .for_each(|t| { + if t.requires_grad() { + g = t.grad(); + t.set_data(&(t.data() - learning_rate*&g)); + t.zero_grad(); + } + }); + } + + fn apply_grads_sgd_momentum(&mut self, learning_rate: f32) { + let mut g: Tensor = Tensor::new(); + let mut velocity: Vec= Tensor::zeros(&[self.size as i64], (Kind::Float, Device::Cpu)).split(1, 0); + let mut vcounter = 0; + const BETA:f32 = 0.9; + + self.values + .iter_mut() + .for_each(|t| { + if t.requires_grad() { + g = t.grad(); + velocity[vcounter] = BETA * &velocity[vcounter] + (1.0 - BETA) * &g; + t.set_data(&(t.data() - learning_rate * &velocity[vcounter])); + t.zero_grad(); + } + vcounter += 1; + }); + } +} + +fn train(mem: &mut Memory, x: &Tensor, y: &Tensor, model: &dyn Compute, epochs: i64, batch_size: i64, errfunc: F, learning_rate: f32) + where F: Fn(&Tensor, &Tensor)-> Tensor + { + let mut error = 
Tensor::from(0.0); + let mut batch_error = Tensor::from(0.0); + let mut pred = Tensor::from(0.0); + for epoch in 0..epochs { + batch_error = Tensor::from(0.0); + for (batchx, batchy) in get_batches(&x, &y, batch_size, true) { + pred = model.forward(mem, &batchx); + error = errfunc(&batchy, &pred); + batch_error += error.detach(); + error.backward(); + mem.apply_grads_sgd_momentum(learning_rate); + } + println!("Epoch: {:?} Error: {:?}", epoch, batch_error/batch_size); + } +} + +fn get_batches(x: &Tensor, y: &Tensor, batch_size: i64, shuffle: bool) -> impl Iterator { + let num_rows = x.size()[0]; + let num_batches = (num_rows + batch_size - 1) / batch_size; + + let indices = if shuffle { + Tensor::randperm(num_rows as i64, (Kind::Int64, Device::Cpu)) + } else + { + let rng = (0..num_rows).collect::>(); + Tensor::from_slice(&rng) + }; + let x = x.index_select(0, &indices); + let y = y.index_select(0, &indices); + + (0..num_batches).map(move |i| { + let start = i * batch_size; + let end = (start + batch_size).min(num_rows); + let batchx: Tensor = x.narrow(0, start, end - start); + let batchy: Tensor = y.narrow(0, start, end - start); + (batchx, batchy) + }) +} + + +fn load_mnist() -> (Tensor, Tensor) { + let data = MnistBuilder::new() + .label_format_digit() + .training_set_length(1000) + .validation_set_length(1000) + .finalize(); + + let train_images = Tensor::from_slice(&data.trn_img); + let val_images = Tensor::from_slice(&data.trn_lbl); + let x = train_images; + let y = val_images; + (x, y) +} + +fn accuracy(target: &Tensor, pred: &Tensor) -> f64 { + let yhat = pred.argmax(1,true).squeeze(); + let eq = target.eq_tensor(&yhat); + let accuracy: f64 = (eq.sum(Kind::Int64) / target.size()[0]).double_value(&[]).into(); + accuracy +} + +struct LinearCircuit { + weights: Vec, + biases: Vec, + inputs: Vec, + outputs: Vec, +} + +impl Circuit for LinearCircuit { + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + let weights_col = meta.advice_column(); + let 
biases_col = meta.advice_column(); + let inputs_col = meta.advice_column(); + let outputs_col = meta.advice_column(); + + let s_mul = meta.selector(); + + let config = LinearConfig { + weights_col, + biases_col, + inputs_col, + outputs_col, + s_mul, + }; + + meta.enable_equality(config.weights_col.into()); + meta.enable_equality(config.biases_col.into()); + meta.enable_equality(config.inputs_col.into()); + meta.enable_equality(config.outputs_col.into()); + + config + } + + fn synthesize(&self, cs: &mut impl Layouter, config: Self::Config) -> Result<(), Error> { + for i in 0..self.weights.len() { + let weight = self.weights[i]; + let bias = self.biases[i]; + let input = self.inputs[i]; + let output = self.outputs[i]; + + cs.assign_advice(|| "weight", config.weights_col, 0, || Ok(Fr::from(weight)))?; + cs.assign_advice(|| "bias", config.biases_col, 0, || Ok(Fr::from(bias)))?; + cs.assign_advice(|| "input", config.inputs_col, 0, || Ok(Fr::from(input)))?; + cs.assign_advice(|| "output", config.outputs_col, 0, || Ok(Fr::from(output)))?; + + config.s_mul.enable(cs, 0); + } + + Ok(()) + } +} + +struct LinearConfig { + weights_col: Column, + biases_col: Column, + inputs_col: Column, + outputs_col: Column, + s_mul: Selector, +} + + +#[benchy::benchmark] +fn run_mnist(bench: &mut BenchmarkRun) { + let mnist = MnistBuilder::new() + .label_format_digit() + .finalize(); + + let (x, y) = load_mnist(); + + let mut m = Memory::new(); + let mnist_model = MNISTModel::new(&mut m); + train(&mut m, &x, &y, &mnist_model, 100, 128, cross_entropy, 0.3); + let out = mnist_model.forward(&m, &x); + +let weights_tensor = weights.view([-1]); +let mut weights: Vec = vec![0.0; weights_tensor.numel() as usize]; +weights_tensor.copy_to(&mut weights, false); + +let biases_tensor = biases.view([-1]); +let mut biases: Vec = vec![0.0; biases_tensor.numel() as usize]; +biases_tensor.copy_to(&mut biases, false); + let inputs = x.view([-1]); + let output_tensor = mnist_model.forward(&m, &x); + let 
outputs: Vec = output_tensor.view([-1]).iter::().unwrap().collect(); + + let params = gen_srs::>(14 as u32); + + let circuit = LinearCircuit { + weights: weights.clone(), // Fill in with your weights + biases: biases.clone(), // Fill in with your biases + inputs: inputs.clone(), // Fill in with your inputs + outputs: outputs.clone(), // Fill in with your outputs + }; + + // Use MNIST data for the tensors + let image_tensor = Tensor::from_slice(&mnist.trn_img); + let label_tensor = Tensor::from_slice(&mnist.trn_lbl); + image_tensor.reshape(&[IMAGE_SIZE as i64, 1]); + label_tensor.reshape(&[1 as i64, 1]); + + + bench.run(|| { + create_keys::, Fr, LinearCircuit>(&circuit, ¶ms) + .unwrap(); + }); + + let pk = + create_keys::, Fr, LinearCircuit>(&circuit, ¶ms).unwrap(); + // Generate the proof + let prover = create_proof_circuit_kzg( + circuit.clone(), + ¶ms, + None, + &pk, + TranscriptType::EVM, + halo2_proofs::poly::kzg::strategy::SingleStrategy::new(¶ms), + // use safe mode to verify that the proof is correct + CheckMode::SAFE, + None, + ); + + // Verify the proof + let strategy = halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params()); + let vk = pk.get_vk(); + // let result = verify_proof_circuit_kzg(params.verifier_params(), proof, vk, strategy); + + // assert!(result.is_ok()); + +} diff --git a/mnist_ezkl/contracts/AttestData.sol b/mnist_ezkl/contracts/AttestData.sol new file mode 100644 index 0000000..e542075 --- /dev/null +++ b/mnist_ezkl/contracts/AttestData.sol @@ -0,0 +1,270 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.20; +import './LoadInstances.sol'; + +// This contract serves as a Data Attestation Verifier for the EZKL model. +// It is designed to read and attest to instances of proofs generated from a specified circuit. +// It is particularly constructed to read only int256 data from specified on-chain contracts' view functions. + +// Overview of the contract functionality: +// 1. 
Initialization: Through the constructor, it sets up the contract calls that the EZKL model will read from. +// 2. Data Quantization: Quantizes the returned data into a scaled fixed-point representation. See the `quantizeData` method for details. +// 3. Static Calls: Makes static calls to fetch data from other contracts. See the `staticCall` method. +// 4. Field Element Conversion: The fixed-point representation is then converted into a field element modulo P using the `toFieldElement` method. +// 5. Data Attestation: The `attestData` method validates that the public instances match the data fetched and processed by the contract. +// 6. Proof Verification: The `verifyWithDataAttestation` method parses the instances out of the encoded calldata and calls the `attestData` method to validate the public instances, +// then calls the `verifyProof` method to verify the proof on the verifier. + +contract DataAttestation is LoadInstances { + /** + * @notice Struct used to make view only calls to accounts to fetch the data that EZKL reads from. + * @param the address of the account to make calls to + * @param the abi encoded function calls to make to the `contractAddress` + */ + struct AccountCall { + address contractAddress; + mapping(uint256 => bytes) callData; + mapping(uint256 => uint256) decimals; + uint callCount; + } + AccountCall[] public accountCalls; + + uint[] public scales; + + address public admin; + + /** + * @notice EZKL P value + * @dev In order to prevent the verifier from accepting two version of the same pubInput, n and the quantity (n + P), where n + P <= 2^256, we require that all instances are stricly less than P. a + * @dev The reason for this is that the assmebly code of the verifier performs all arithmetic operations modulo P and as a consequence can't distinguish between n and n + P. 
+ */ + uint256 constant ORDER = uint256(0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001); + + uint256 constant INPUT_CALLS = 0; + + uint256 constant OUTPUT_CALLS = 0; + + uint8 public instanceOffset; + + /** + * @dev Initialize the contract with account calls the EZKL model will read from. + * @param _contractAddresses - The calls to all the contracts EZKL reads storage from. + * @param _callData - The abi encoded function calls to make to the `contractAddress` that EZKL reads storage from. + */ + constructor( + address[] memory _contractAddresses, + bytes[][] memory _callData, + uint256[][] memory _decimals, + uint[] memory _scales, + uint8 _instanceOffset, + address _admin + ) { + admin = _admin; + for (uint i; i < _scales.length; i++) { + scales.push(1 << _scales[i]); + } + populateAccountCalls(_contractAddresses, _callData, _decimals); + instanceOffset = _instanceOffset; + } + + function updateAdmin(address _admin) external { + require(msg.sender == admin, "Only admin can update admin"); + if(_admin == address(0)) { + revert(); + } + admin = _admin; + } + + function updateAccountCalls( + address[] memory _contractAddresses, + bytes[][] memory _callData, + uint256[][] memory _decimals + ) external { + require(msg.sender == admin, "Only admin can update instanceOffset"); + populateAccountCalls(_contractAddresses, _callData, _decimals); + } + + function populateAccountCalls( + address[] memory _contractAddresses, + bytes[][] memory _callData, + uint256[][] memory _decimals + ) internal { + require( + _contractAddresses.length == _callData.length && + accountCalls.length == _contractAddresses.length, + "Invalid input length" + ); + require( + _decimals.length == _contractAddresses.length, + "Invalid number of decimals" + ); + // fill in the accountCalls storage array + uint counter = 0; + for (uint256 i = 0; i < _contractAddresses.length; i++) { + AccountCall storage accountCall = accountCalls[i]; + accountCall.contractAddress = 
_contractAddresses[i]; + accountCall.callCount = _callData[i].length; + for (uint256 j = 0; j < _callData[i].length; j++) { + accountCall.callData[j] = _callData[i][j]; + accountCall.decimals[j] = 10 ** _decimals[i][j]; + } + // count the total number of storage reads across all of the accounts + counter += _callData[i].length; + } + require(counter == INPUT_CALLS + OUTPUT_CALLS, "Invalid number of calls"); + } + + function mulDiv( + uint256 x, + uint256 y, + uint256 denominator + ) internal pure returns (uint256 result) { + unchecked { + uint256 prod0; + uint256 prod1; + assembly { + let mm := mulmod(x, y, not(0)) + prod0 := mul(x, y) + prod1 := sub(sub(mm, prod0), lt(mm, prod0)) + } + + if (prod1 == 0) { + return prod0 / denominator; + } + + require(denominator > prod1, "Math: mulDiv overflow"); + + uint256 remainder; + assembly { + remainder := mulmod(x, y, denominator) + prod1 := sub(prod1, gt(remainder, prod0)) + prod0 := sub(prod0, remainder) + } + + uint256 twos = denominator & (~denominator + 1); + assembly { + denominator := div(denominator, twos) + prod0 := div(prod0, twos) + twos := add(div(sub(0, twos), twos), 1) + } + + prod0 |= prod1 * twos; + + uint256 inverse = (3 * denominator) ^ 2; + + inverse *= 2 - denominator * inverse; + inverse *= 2 - denominator * inverse; + inverse *= 2 - denominator * inverse; + inverse *= 2 - denominator * inverse; + inverse *= 2 - denominator * inverse; + inverse *= 2 - denominator * inverse; + + result = prod0 * inverse; + return result; + } + } + /** + * @dev Quantize the data returned from the account calls to the scale used by the EZKL model. + * @param data - The data returned from the account calls. + * @param decimals - The number of decimals the data returned from the account calls has (for floating point representation). + * @param scale - The scale used to convert the floating point value into a fixed point value. 
+ */ + function quantizeData( + bytes memory data, + uint256 decimals, + uint256 scale + ) internal pure returns (int256 quantized_data) { + int x = abi.decode(data, (int256)); + bool neg = x < 0; + if (neg) x = -x; + uint output = mulDiv(uint256(x), scale, decimals); + if (mulmod(uint256(x), scale, decimals) * 2 >= decimals) { + output += 1; + } + quantized_data = neg ? -int256(output): int256(output); + } + /** + * @dev Make a static call to the account to fetch the data that EZKL reads from. + * @param target - The address of the account to make calls to. + * @param data - The abi encoded function calls to make to the `contractAddress` that EZKL reads storage from. + * @return The data returned from the account calls. (Must come from either a view or pure function. Will throw an error otherwise) + */ + function staticCall( + address target, + bytes memory data + ) internal view returns (bytes memory) { + (bool success, bytes memory returndata) = target.staticcall(data); + if (success) { + if (returndata.length == 0) { + require( + target.code.length > 0, + "Address: call to non-contract" + ); + } + return returndata; + } else { + revert("Address: low-level call failed"); + } + } + /** + * @dev Convert the fixed point quantized data into a field element. + * @param x - The quantized data. + * @return field_element - The field element. + */ + function toFieldElement(int256 x) internal pure returns (uint256 field_element) { + // The casting down to uint256 is safe because the order is about 2^254, and the value + // of x ranges of -2^127 to 2^127, so x + int(ORDER) is always positive. + return uint256(x + int(ORDER)) % ORDER; + } + + /** + * @dev Make the account calls to fetch the data that EZKL reads from and attest to the data. + * @param instances - The public instances to the proof (the data in the proof that publicly accessible to the verifier). 
+ */ + function attestData(uint256[] memory instances) internal view { + require( + instances.length >= INPUT_CALLS + OUTPUT_CALLS, + "Invalid public inputs length" + ); + uint256 _accountCount = accountCalls.length; + uint counter = 0; + for (uint8 i = 0; i < _accountCount; ++i) { + address account = accountCalls[i].contractAddress; + for (uint8 j = 0; j < accountCalls[i].callCount; j++) { + bytes memory returnData = staticCall( + account, + accountCalls[i].callData[j] + ); + uint256 scale = scales[counter]; + int256 quantized_data = quantizeData( + returnData, + accountCalls[i].decimals[j], + scale + ); + uint256 field_element = toFieldElement(quantized_data); + require( + field_element == instances[counter + instanceOffset], + "Public input does not match" + ); + counter++; + } + } + } + + + function verifyWithDataAttestation( + address verifier, + bytes calldata encoded + ) public view returns (bool) { + require(verifier.code.length > 0,"Address: call to non-contract"); + attestData(getInstancesCalldata(encoded)); + // static call the verifier contract to verify the proof + (bool success, bytes memory returndata) = verifier.staticcall(encoded); + + if (success) { + return abi.decode(returndata, (bool)); + } else { + revert("low-level call to verifier failed"); + } + } +} diff --git a/mnist_ezkl/contracts/LoadInstances.sol b/mnist_ezkl/contracts/LoadInstances.sol new file mode 100644 index 0000000..41fe829 --- /dev/null +++ b/mnist_ezkl/contracts/LoadInstances.sol @@ -0,0 +1,92 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.20; +contract LoadInstances { + /** + * @dev Parse the instances array from the Halo2Verifier encoded calldata. + * @notice must pass encoded bytes from memory + * @param encoded - verifier calldata + */ + function getInstancesMemory( + bytes memory encoded + ) internal pure returns (uint256[] memory instances) { + bytes4 funcSig; + uint256 instances_offset; + uint256 instances_length; + assembly { + // fetch function sig. 
Either `verifyProof(bytes,uint256[])` or `verifyProof(address,bytes,uint256[])` + funcSig := mload(add(encoded, 0x20)) + + // Fetch instances offset which is 4 + 32 + 32 bytes away from + // start of encoded for `verifyProof(bytes,uint256[])`, + // and 4 + 32 + 32 +32 away for `verifyProof(address,bytes,uint256[])` + + instances_offset := mload( + add(encoded, add(0x44, mul(0x20, eq(funcSig, 0xaf83a18d)))) + ) + + instances_length := mload(add(add(encoded, 0x24), instances_offset)) + } + instances = new uint256[](instances_length); // Allocate memory for the instances array. + assembly { + // Now instances points to the start of the array data + // (right after the length field). + for { + let i := 0x20 + } lt(i, add(mul(instances_length, 0x20), 0x20)) { + i := add(i, 0x20) + } { + mstore( + add(instances, i), + mload(add(add(encoded, add(i, 0x24)), instances_offset)) + ) + } + } + } + /** + * @dev Parse the instances array from the Halo2Verifier encoded calldata. + * @notice must pass encoded bytes from calldata + * @param encoded - verifier calldata + */ + function getInstancesCalldata( + bytes calldata encoded + ) internal pure returns (uint256[] memory instances) { + bytes4 funcSig; + uint256 instances_offset; + uint256 instances_length; + assembly { + // fetch function sig. Either `verifyProof(bytes,uint256[])` or `verifyProof(address,bytes,uint256[])` + funcSig := calldataload(encoded.offset) + + // Fetch instances offset which is 4 + 32 + 32 bytes away from + // start of encoded for `verifyProof(bytes,uint256[])`, + // and 4 + 32 + 32 +32 away for `verifyProof(address,bytes,uint256[])` + + instances_offset := calldataload( + add( + encoded.offset, + add(0x24, mul(0x20, eq(funcSig, 0xaf83a18d))) + ) + ) + + instances_length := calldataload(add(add(encoded.offset, 0x04), instances_offset)) + } + instances = new uint256[](instances_length); // Allocate memory for the instances array. 
        assembly{
            // Now instances points to the start of the array data
            // (right after the length field).

            // Copy each 32-byte instance word from calldata into memory.
            for {
                let i := 0x20
            } lt(i, add(mul(instances_length, 0x20), 0x20)) {
                i := add(i, 0x20)
            } {
                mstore(
                    add(instances, i),
                    calldataload(
                        add(add(encoded.offset, add(i, 0x04)), instances_offset)
                    )
                )
            }
        }
    }
}
\ No newline at end of file
diff --git a/mnist_ezkl/contracts/QuantizeData.sol b/mnist_ezkl/contracts/QuantizeData.sol
new file mode 100644
index 0000000..ce832b4
--- /dev/null
+++ b/mnist_ezkl/contracts/QuantizeData.sol
@@ -0,0 +1,135 @@
// SPDX-License-Identifier: GPL-3.0

pragma solidity ^0.8.17;

contract QuantizeData {
    /**
     * @notice EZKL P value
     * @dev In order to prevent the verifier from accepting two versions of the same instance, n and the quantity (n + P), where n + P <= 2^256, we require that all instances are strictly less than P.
     * @dev The reason for this is that the assembly code of the verifier performs all arithmetic operations modulo P and as a consequence can't distinguish between n and n + P.
     */
    uint256 constant ORDER =
        uint256(
            0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001
        );

    /**
     * @notice Calculates floor(x * y / denominator) with full precision. Throws if result overflows a uint256 or denominator == 0
     * @dev Original credit to Remco Bloemen under MIT license (https://xn--2-umb.com/21/muldiv)
     * with further edits by Uniswap Labs also under MIT license.
     */
    function mulDiv(
        uint256 x,
        uint256 y,
        uint256 denominator
    ) internal pure returns (uint256 result) {
        unchecked {
            // 512-bit multiply [prod1 prod0] = x * y. Compute the product mod 2^256 and mod 2^256 - 1, then use
            // the Chinese Remainder Theorem to reconstruct the 512 bit result. The result is stored in two 256
            // variables such that product = prod1 * 2^256 + prod0.
            uint256 prod0; // Least significant 256 bits of the product
            uint256 prod1; // Most significant 256 bits of the product
            assembly {
                let mm := mulmod(x, y, not(0))
                prod0 := mul(x, y)
                prod1 := sub(sub(mm, prod0), lt(mm, prod0))
            }

            // Handle non-overflow cases, 256 by 256 division.
            if (prod1 == 0) {
                // Solidity will revert if denominator == 0, unlike the div opcode on its own.
                // The surrounding unchecked block does not change this fact.
                // See https://docs.soliditylang.org/en/latest/control-structures.html#checked-or-unchecked-arithmetic.
                return prod0 / denominator;
            }

            // Make sure the result is less than 2^256. Also prevents denominator == 0.
            require(denominator > prod1, "Math: mulDiv overflow");

            ///////////////////////////////////////////////
            // 512 by 256 division.
            ///////////////////////////////////////////////

            // Make division exact by subtracting the remainder from [prod1 prod0].
            uint256 remainder;
            assembly {
                // Compute remainder using mulmod.
                remainder := mulmod(x, y, denominator)

                // Subtract 256 bit number from 512 bit number.
                prod1 := sub(prod1, gt(remainder, prod0))
                prod0 := sub(prod0, remainder)
            }

            // Factor powers of two out of denominator and compute largest power of two divisor of denominator. Always >= 1.
            // See https://cs.stackexchange.com/q/138556/92363.

            // Does not overflow because the denominator cannot be zero at this stage in the function.
            uint256 twos = denominator & (~denominator + 1);
            assembly {
                // Divide denominator by twos.
                denominator := div(denominator, twos)

                // Divide [prod1 prod0] by twos.
                prod0 := div(prod0, twos)

                // Flip twos such that it is 2^256 / twos. If twos is zero, then it becomes one.
                twos := add(div(sub(0, twos), twos), 1)
            }

            // Shift in bits from prod1 into prod0.
            prod0 |= prod1 * twos;

            // Invert denominator mod 2^256. Now that denominator is an odd number, it has an inverse modulo 2^256 such
            // that denominator * inv = 1 mod 2^256. Compute the inverse by starting with a seed that is correct for
            // four bits. That is, denominator * inv = 1 mod 2^4.
            uint256 inverse = (3 * denominator) ^ 2;

            // Use the Newton-Raphson iteration to improve the precision. Thanks to Hensel's lifting lemma, this also works
            // in modular arithmetic, doubling the correct bits in each step.
            inverse *= 2 - denominator * inverse; // inverse mod 2^8
            inverse *= 2 - denominator * inverse; // inverse mod 2^16
            inverse *= 2 - denominator * inverse; // inverse mod 2^32
            inverse *= 2 - denominator * inverse; // inverse mod 2^64
            inverse *= 2 - denominator * inverse; // inverse mod 2^128
            inverse *= 2 - denominator * inverse; // inverse mod 2^256

            // Because the division is now exact we can divide by multiplying with the modular inverse of denominator.
            // This will give us the correct result modulo 2^256. Since the preconditions guarantee that the outcome is
            // less than 2^256, this is the final result. We don't need to compute the high bits of the result and prod1
            // is no longer required.
            result = prod0 * inverse;
            return result;
        }
    }

    /**
     * @notice Decode each abi-encoded int256 in `data` and rescale it by
     * (1 << scales[i]) / 10**decimals[i], rounding half up and preserving sign.
     */
    function quantize_data(
        bytes[] memory data,
        uint256[] memory decimals,
        uint256[] memory scales
    ) external pure returns (int256[] memory quantized_data) {
        quantized_data = new int256[](data.length);
        for (uint i; i < data.length; i++) {
            int x = abi.decode(data[i], (int256));
            bool neg = x < 0;
            if (neg) x = -x;
            uint denom = 10 ** decimals[i];
            uint scale = 1 << scales[i];
            uint output = mulDiv(uint256(x), scale, denom);
            // Round half up: bump the floored quotient when remainder * 2 >= denom.
            if (mulmod(uint256(x), scale, denom) * 2 >= denom) {
                output += 1;
            }

            quantized_data[i] = neg ?
-int256(output) : int256(output);
        }
    }

    /// Map each signed quantized value into the scalar field as a
    /// non-negative representative modulo ORDER.
    function to_field_element(
        int128[] memory quantized_data
    ) public pure returns (uint256[] memory output) {
        output = new uint256[](quantized_data.length);
        for (uint i; i < quantized_data.length; i++) {
            // Adding ORDER before the cast keeps the value non-negative for negative inputs.
            output[i] = uint256(quantized_data[i] + int(ORDER)) % ORDER;
        }
    }
}
diff --git a/mnist_ezkl/contracts/TestReads.sol b/mnist_ezkl/contracts/TestReads.sol
new file mode 100644
index 0000000..2a263f0
--- /dev/null
+++ b/mnist_ezkl/contracts/TestReads.sol
@@ -0,0 +1,12 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.17;

// Test fixture: copies the constructor-supplied numbers into a public array.
contract TestReads {
    int[] public arr;

    constructor(int256[] memory _numbers) {
        for (uint256 i = 0; i < _numbers.length; i++) {
            arr.push(_numbers[i]);
        }
    }
}
diff --git a/mnist_ezkl/src/bin/ezkl.rs b/mnist_ezkl/src/bin/ezkl.rs
new file mode 100644
index 0000000..b754245
--- /dev/null
+++ b/mnist_ezkl/src/bin/ezkl.rs
@@ -0,0 +1,78 @@
// ignore file if compiling for wasm

#[cfg(not(target_arch = "wasm32"))]
use clap::Parser;
#[cfg(not(target_arch = "wasm32"))]
use colored_json::ToColoredJson;
#[cfg(not(target_arch = "wasm32"))]
use ezkl::commands::Cli;
#[cfg(not(target_arch = "wasm32"))]
use ezkl::execute::run;
#[cfg(not(target_arch = "wasm32"))]
use ezkl::logger::init_logger;
#[cfg(not(target_arch = "wasm32"))]
use log::{error, info};
#[cfg(not(target_arch = "wasm32"))]
use rand::prelude::SliceRandom;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(feature = "icicle")]
use std::env;
#[cfg(not(target_arch = "wasm32"))]
use std::error::Error;

/// CLI entry point: parse arguments, print the banner, run the requested
/// command on a single-threaded tokio runtime, and log success or failure.
// NOTE(review): the return type was corrupted to `Result<(), Box>` in transit;
// restored to `Box<dyn Error>`, matching the `std::error::Error` import and the
// `?` operators below.
#[tokio::main(flavor = "current_thread")]
#[cfg(not(target_arch = "wasm32"))]
pub async fn main() -> Result<(), Box<dyn Error>> {
    let args = Cli::parse();
    init_logger();
    banner();
    #[cfg(feature = "icicle")]
    if env::var("ENABLE_ICICLE_GPU").is_ok() {
        info!("Running with ICICLE GPU");
    } else {
        info!("Running with CPU");
    }
    info!("command: \n {}", &args.as_json()?.to_colored_json_auto()?);
    let res = run(args.command).await;
    // Log the outcome but still propagate it to the caller / exit code.
    match &res {
        Ok(_) => info!("succeeded"),
        Err(e) => error!("failed: {}", e),
    };
    res
}

/// No-op entry point for wasm targets (the CLI is not built for wasm).
#[cfg(target_arch = "wasm32")]
pub fn main() {}

/// Print the EZKL startup banner with a randomly chosen "L" tagline.
#[cfg(not(target_arch = "wasm32"))]
fn banner() {
    let ell: Vec<&str> = vec![
        "for Neural Networks",
        "Linear Algebra",
        "for Layers",
        "for the Laconic",
        "Learning",
        "for Liberty",
        "for the Lyrical",
    ];
    info!(
        "{}",
        format!(
            "

        ███████╗███████╗██╗ ██╗██║
        ██╔════╝╚══███╔╝██║ ██╔╝██║
        █████╗ ███╔╝ █████╔╝ ██║
        ██╔══╝ ███╔╝ ██╔═██╗ ██║
        ███████╗███████╗██║ ██╗███████╗
        ╚══════╝╚══════╝╚═╝ ╚═╝╚══════╝

        -----------------------------------------------------------
        Easy Zero Knowledge {}.
        -----------------------------------------------------------

        ",
            ell.choose(&mut rand::thread_rng()).unwrap()
        )
    );
}
diff --git a/mnist_ezkl/src/circuit/mod.rs b/mnist_ezkl/src/circuit/mod.rs
new file mode 100644
index 0000000..6153e66
--- /dev/null
+++ b/mnist_ezkl/src/circuit/mod.rs
@@ -0,0 +1,18 @@
///
pub mod modules;

///
pub mod table;

///
pub mod utils;

///
pub mod ops;

pub use ops::chip::*;
pub use ops::*;

/// Tests
#[cfg(test)]
mod tests;
diff --git a/mnist_ezkl/src/circuit/modules/elgamal.rs b/mnist_ezkl/src/circuit/modules/elgamal.rs
new file mode 100644
index 0000000..c93c1a1
--- /dev/null
+++ b/mnist_ezkl/src/circuit/modules/elgamal.rs
@@ -0,0 +1,881 @@
/*
An easy-to-use implementation of the ElGamal Encryption in the form of a Halo2 Chip.
Huge thank you to https://github.com/timoftime/ for providing the inspiration and launching point for this <3 .
+*/ + +/// +mod add_chip; + +use crate::circuit::modules::poseidon::spec::PoseidonSpec; +use crate::tensor::{Tensor, ValTensor, ValType}; +use add_chip::{AddChip, AddConfig, AddInstruction}; +use ark_std::rand::{CryptoRng, RngCore}; +use halo2_proofs::arithmetic::Field; +use halo2_proofs::circuit::{AssignedCell, Chip, Layouter, Value}; +use halo2_proofs::plonk; +use halo2_proofs::plonk::{Advice, Column, ConstraintSystem, Error, Instance}; +use halo2_wrong_ecc::integer::rns::{Common, Integer, Rns}; +use halo2_wrong_ecc::maingate::{ + MainGate, MainGateConfig, RangeChip, RangeConfig, RangeInstructions, RegionCtx, +}; +use halo2_wrong_ecc::{AssignedPoint, BaseFieldEccChip, EccConfig}; +use halo2curves::bn256::{Fq, Fr, G1Affine, G1}; +use halo2curves::group::cofactor::CofactorCurveAffine; +use halo2curves::group::{Curve, Group}; +use halo2curves::CurveAffine; +use serde::{Deserialize, Serialize}; +use std::ops::{Mul, MulAssign}; +use std::rc::Rc; +use std::vec; + +use super::poseidon::{PoseidonChip, PoseidonConfig}; +use super::Module; + +// Absolute offsets for public inputs. +const C1_X: usize = 0; +const C1_Y: usize = 1; +const SK_H: usize = 2; +const C2_H: usize = 3; + +/// +const NUMBER_OF_LIMBS: usize = 4; +const BIT_LEN_LIMB: usize = 64; +/// The number of instance columns used by the ElGamal circuit. +pub const NUM_INSTANCE_COLUMNS: usize = 1; + +/// The poseidon hash width. +pub const POSEIDON_WIDTH: usize = 2; +/// The poseidon hash rate. +pub const POSEIDON_RATE: usize = 1; +/// The poseidon len +pub const POSEIDON_LEN: usize = 2; + +#[derive(Debug)] +/// A chip implementing ElGamal encryption. +pub struct ElGamalChip { + /// The configuration for this chip. + pub config: ElGamalConfig, + /// The ECC chip. + ecc: BaseFieldEccChip, + /// The Poseidon hash chip. + poseidon: PoseidonChip, + /// The addition chip. + add: AddChip, +} + +#[derive(Debug, Clone)] +/// Configuration for the ElGamal chip. 
+pub struct ElGamalConfig { + main_gate_config: MainGateConfig, + range_config: RangeConfig, + poseidon_config: PoseidonConfig, + add_config: AddConfig, + plaintext_col: Column, + /// The column used for the instance. + pub instance: Column, + /// The config has been initialized. + pub initialized: bool, +} + +impl ElGamalConfig { + fn config_range(&self, layouter: &mut impl Layouter) -> Result<(), Error> { + let range_chip = RangeChip::::new(self.range_config.clone()); + range_chip.load_table(layouter)?; + Ok(()) + } + + fn ecc_chip_config(&self) -> EccConfig { + EccConfig::new(self.range_config.clone(), self.main_gate_config.clone()) + } +} + +impl Chip for ElGamalChip { + type Config = ElGamalConfig; + type Loaded = (); + + fn config(&self) -> &Self::Config { + &self.config + } + + fn loaded(&self) -> &Self::Loaded { + &() + } +} + +impl ElGamalChip { + /// Create a new `ElGamalChip`. + pub fn new(p: ElGamalConfig) -> ElGamalChip { + ElGamalChip { + ecc: BaseFieldEccChip::new(p.ecc_chip_config()), + poseidon: PoseidonChip::new(p.poseidon_config.clone()), + add: AddChip::construct(p.add_config.clone()), + config: p, + } + } + + /// Configure the chip. 
+ fn configure(meta: &mut ConstraintSystem) -> ElGamalConfig { + let main_gate_config = MainGate::::configure(meta); + let advices = main_gate_config.advices(); + let main_fixed_columns = main_gate_config.fixed(); + let instance = main_gate_config.instance(); + + let rc_a = main_fixed_columns[3..5].try_into().unwrap(); + let rc_b = [meta.fixed_column(), meta.fixed_column()]; + + meta.enable_constant(rc_b[0]); + + let rns = Rns::::construct(); + + let overflow_bit_lens = rns.overflow_lengths(); + let composition_bit_lens = vec![BIT_LEN_LIMB / NUMBER_OF_LIMBS]; + + let range_config = RangeChip::::configure( + meta, + &main_gate_config, + composition_bit_lens, + overflow_bit_lens, + ); + + let poseidon_config = + PoseidonChip::::configure_with_cols( + meta, + advices[0], + rc_a, + rc_b, + advices[1..3].try_into().unwrap(), + None, + ); + + let add_config = AddChip::configure(meta, advices[0], advices[1], advices[2]); + + let plaintext_col = advices[1]; + + ElGamalConfig { + poseidon_config, + main_gate_config, + range_config, + add_config, + plaintext_col, + instance, + initialized: false, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +/// The variables used in the ElGamal circuit. +pub struct ElGamalVariables { + /// The randomness used in the encryption. + pub r: Fr, + /// The public key. + pub pk: G1Affine, + /// The secret key. + pub sk: Fr, + /// The window size used in the ECC chip. + pub window_size: usize, + /// The auxiliary generator used in the ECC chip. + pub aux_generator: G1Affine, +} + +impl Default for ElGamalVariables { + fn default() -> Self { + Self { + r: Fr::zero(), + pk: G1Affine::identity(), + sk: Fr::zero(), + window_size: 4, + aux_generator: G1Affine::identity(), + } + } +} + +impl ElGamalVariables { + /// Create new variables. 
+ pub fn new(r: Fr, pk: G1Affine, sk: Fr, window_size: usize, aux_generator: G1Affine) -> Self { + Self { + r, + pk, + sk, + window_size, + aux_generator, + } + } + + /// Generate random variables. + pub fn gen_random(mut rng: &mut R) -> Self { + // get a random element from the scalar field + let sk = Fr::random(&mut rng); + + // compute secret_key*generator to derive the public key + // With BN256, we create the private key from a random number. This is a private key value (sk + // and a public key mapped to the G2 curve:: pk=sk.G2 + let mut pk = G1::generator(); + pk.mul_assign(sk); + + Self { + r: Fr::random(&mut rng), + pk: pk.to_affine(), + sk, + window_size: 4, + aux_generator: ::CurveExt::random(rng).to_affine(), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +/// The cipher returned from the ElGamal encryption. +pub struct ElGamalCipher { + /// c1 := r*G + pub c1: G1, + /// c2 := m*s + pub c2: Vec, +} + +#[derive(Debug, Clone)] +/// A gadget implementing ElGamal encryption. +pub struct ElGamalGadget { + /// The configuration for this gadget. + pub config: ElGamalConfig, + /// The variables used in this gadget. + variables: Option, +} + +impl ElGamalGadget { + /// Load the variables into the gadget. + pub fn load_variables(&mut self, variables: ElGamalVariables) { + self.variables = Some(variables); + } + + fn rns() -> Rc> { + let rns = Rns::::construct(); + Rc::new(rns) + } + + /// Encrypt a message using the public key. 
+ pub fn encrypt(pk: G1Affine, msg: Vec, r: Fr) -> ElGamalCipher { + let g = G1Affine::generator(); + let c1 = g.mul(&r); + + let coords = pk.mul(&r).to_affine().coordinates().unwrap(); + + let x = Integer::from_fe(*coords.x(), Self::rns()); + let y = Integer::from_fe(*coords.y(), Self::rns()); + + let dh = PoseidonChip::::run( + [x.native(), y.native()].to_vec(), + ) + .unwrap()[0][0]; + + let mut c2 = vec![]; + + for m in &msg { + c2.push(m + dh); + } + + ElGamalCipher { c1, c2 } + } + + /// Hash the msssage to be used as a public input. + pub fn hash_encrypted_msg(msg: Vec) -> Fr { + PoseidonChip::::run(msg).unwrap() + [0][0] + } + + /// Hash the secret key to be used as a public input. + pub fn hash_sk(sk: Fr) -> Fr { + PoseidonChip::::run(vec![sk, sk]) + .unwrap()[0][0] + } + + /// Decrypt a ciphertext using the secret key. + pub fn decrypt(cipher: &ElGamalCipher, sk: Fr) -> Vec { + let c1 = cipher.c1; + let c2 = cipher.c2.clone(); + + let s = c1.mul(sk).to_affine().coordinates().unwrap(); + + let x = Integer::from_fe(*s.x(), Self::rns()); + let y = Integer::from_fe(*s.y(), Self::rns()); + + let dh = PoseidonChip::::run( + [x.native(), y.native()].to_vec(), + ) + .unwrap()[0][0]; + + let mut msg = vec![]; + for encrypted_m in &c2 { + msg.push(encrypted_m - dh); + } + + msg + } + + /// Get the public inputs for the circuit. 
+ pub fn get_instances(cipher: &ElGamalCipher, sk_hash: Fr) -> Vec> { + let mut c1_and_sk = cipher + .c1 + .to_affine() + .coordinates() + .map(|c| { + let x = Integer::from_fe(*c.x(), Self::rns()); + let y = Integer::from_fe(*c.y(), Self::rns()); + + vec![x.native(), y.native()] + }) + .unwrap(); + + c1_and_sk.push(sk_hash); + + c1_and_sk.push(Self::hash_encrypted_msg(cipher.c2.clone())); + + vec![c1_and_sk] + } + + pub(crate) fn verify_encrypted_msg_hash( + &self, + mut layouter: impl Layouter, + config: &ElGamalConfig, + encrypted_msg: &[AssignedCell], + ) -> Result, plonk::Error> { + let chip = ElGamalChip::new(config.clone()); + + // compute dh = poseidon_hash(randomness*pk) + let encrypted_msg_hash = { + let poseidon_message = + Tensor::from(encrypted_msg.iter().map(|m| ValType::from(m.clone()))); + + chip.poseidon.layout( + &mut layouter.namespace(|| "Poseidon hash (encrypted_msg)"), + &[poseidon_message.into()], + 0, + )? + }; + + match &encrypted_msg_hash + .get_inner_tensor() + .map_err(|_| plonk::Error::Synthesis)?[0] + { + ValType::PrevAssigned(v) => Ok(v.clone()), + _ => { + log::error!("poseidon hash should be an assigned value"); + Err(plonk::Error::Synthesis) + } + } + } + + /// Hash the secret key to be used as a public input. + pub(crate) fn verify_sk_hash( + &self, + mut layouter: impl Layouter, + config: &ElGamalConfig, + sk: &AssignedCell, + ) -> Result, plonk::Error> { + let chip = ElGamalChip::new(config.clone()); + + // compute dh = poseidon_hash(randomness*pk) + let sk_hash = { + let poseidon_message = + Tensor::from([ValType::from(sk.clone()), ValType::from(sk.clone())].into_iter()); + + chip.poseidon.layout( + &mut layouter.namespace(|| "Poseidon hash (sk)"), + &[poseidon_message.into()], + 0, + )? 
+ }; + + let sk_hash = match &sk_hash + .get_inner_tensor() + .map_err(|_| plonk::Error::Synthesis)?[0] + { + ValType::PrevAssigned(v) => v.clone(), + _ => { + log::error!("poseidon hash should be an assigned value"); + return Err(plonk::Error::Synthesis); + } + }; + + Ok(sk_hash) + } + + pub(crate) fn verify_secret( + &self, + mut layouter: impl Layouter, + config: &ElGamalConfig, + sk: &AssignedCell, + ) -> Result<[AssignedPoint; 2], plonk::Error> { + let mut chip = ElGamalChip::new(config.clone()); + + let g = G1Affine::generator(); + + let variables = match self.variables { + Some(ref variables) => variables, + None => { + log::error!("variables not loaded"); + return Err(plonk::Error::Synthesis); + } + }; + + // compute s = randomness*pk + let s = variables.pk.mul(variables.r).to_affine(); + let c1 = g.mul(variables.r).to_affine(); + + layouter.assign_region( + || "obtain_s", + |region| { + let offset = 0; + let ctx = &mut RegionCtx::new(region, offset); + + chip.ecc + .assign_aux_generator(ctx, Value::known(variables.aux_generator))?; + chip.ecc.assign_aux(ctx, variables.window_size, 1)?; + + let s = chip.ecc.assign_point(ctx, Value::known(s)).unwrap(); + // compute c1 = randomness*generator + let c1 = chip.ecc.assign_point(ctx, Value::known(c1)).unwrap(); + + let s_from_sk = chip.ecc.mul(ctx, &c1, sk, variables.window_size).unwrap(); + + chip.ecc.assert_equal(ctx, &s, &s_from_sk)?; + + Ok([s, c1]) + }, + ) + } + + pub(crate) fn verify_encryption( + &self, + mut layouter: impl Layouter, + config: &ElGamalConfig, + m: &AssignedCell, + s: &AssignedPoint, + ) -> Result, plonk::Error> { + let chip = ElGamalChip::new(config.clone()); + + // compute dh = poseidon_hash(randomness*pk) + let dh = { + let poseidon_message = Tensor::from( + [ + ValType::from(s.x().native().clone()), + ValType::from(s.y().native().clone()), + ] + .into_iter(), + ); + + chip.poseidon.layout( + &mut layouter.namespace(|| "Poseidon hasher"), + &[poseidon_message.into()], + 0, + )? 
+ }; + + let dh = match &dh.get_inner_tensor().map_err(|_| plonk::Error::Synthesis)?[0] { + ValType::PrevAssigned(v) => v.clone(), + _ => { + log::error!("poseidon hash should be an assigned value"); + return Err(plonk::Error::Synthesis); + } + }; + + // compute c2 = poseidon_hash(nk, rho) + psi. + let c2 = chip.add.add( + layouter.namespace(|| "c2 = poseidon_hash(randomness*pk) + m"), + &dh, + m, + )?; + + Ok(c2) + } +} + +impl Module for ElGamalGadget { + type Config = ElGamalConfig; + type InputAssignments = (Vec>, AssignedCell); + type RunInputs = (Vec, ElGamalVariables); + type Params = (); + + fn new(config: Self::Config) -> Self { + Self { + config, + variables: None, + } + } + + fn configure(meta: &mut ConstraintSystem, _: Self::Params) -> Self::Config { + ElGamalChip::configure(meta) + } + + fn name(&self) -> &'static str { + "ElGamal" + } + + fn instance_increment_input(&self) -> Vec { + // in order + // 1. c1, sk_hash, c2_hash + vec![4] + } + + fn run(input: Self::RunInputs) -> Result>, Box> { + let start_time = instant::Instant::now(); + + let (input, var) = input; + let len = input.len(); + + let cipher = Self::encrypt(var.pk, input, var.r); + // keep 1 empty (maingate instance variable). + let mut public_inputs: Vec> = vec![]; + public_inputs.extend(Self::get_instances(&cipher, Self::hash_sk(var.sk))); + + log::trace!("run (N={:?}) took: {:?}", len, start_time.elapsed()); + + Ok(public_inputs) + } + + fn layout_inputs( + &self, + layouter: &mut impl Layouter, + inputs: &[ValTensor], + ) -> Result { + assert_eq!(inputs.len(), 2); + let message = inputs[0].clone(); + let sk = inputs[1].clone(); + + let start_time = instant::Instant::now(); + let (msg_var, sk_var) = layouter.assign_region( + || "plaintext", + |mut region| { + let msg_var: Result>, Error> = match &message { + ValTensor::Value { inner: v, .. 
} => v + .iter() + .enumerate() + .map(|(i, value)| match value { + ValType::Value(v) => region.assign_advice( + || format!("load message_{}", i), + self.config.plaintext_col, + i, + || *v, + ), + ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => { + Ok(v.clone()) + } + ValType::Constant(f) => region.assign_advice_from_constant( + || format!("load message_{}", i), + self.config.plaintext_col, + i, + *f, + ), + e => { + log::error!("wrong input type: {:?}", e); + Err(Error::Synthesis) + } + }) + .collect(), + ValTensor::Instance { + dims, + inner: col, + idx, + initial_offset, + .. + } => { + // this should never ever fail + let num_elems = dims[*idx].iter().product::(); + (0..num_elems) + .map(|i| { + region.assign_advice_from_instance( + || "pub input anchor", + *col, + initial_offset + i, + self.config.plaintext_col, + i, + ) + }) + .collect() + } + }; + + let sk = match sk.get_inner_tensor().unwrap()[0] { + ValType::Value(v) => v, + _ => { + log::error!("wrong input type"); + return Err(Error::Synthesis); + } + }; + + let msg_var = msg_var?; + + let sk_var = region.assign_advice( + || "sk", + self.config.plaintext_col, + msg_var.len(), + || sk, + )?; + + Ok((msg_var, sk_var)) + }, + )?; + let duration = start_time.elapsed(); + log::trace!("layout inputs took: {:?}", duration); + + Ok((msg_var, sk_var)) + } + + fn layout( + &self, + layouter: &mut impl Layouter, + inputs: &[ValTensor], + row_offset: usize, + ) -> Result, Error> { + let start_time = instant::Instant::now(); + + // if all equivalent to 0, then we are in the first row of the circuit + if !self.config.initialized { + self.config.config_range(layouter).unwrap(); + } + + let (msg_var, sk_var) = self.layout_inputs(layouter, inputs)?; + + let [s, c1] = self.verify_secret( + layouter.namespace(|| "verify_secret"), + &self.config, + &sk_var, + )?; + + // Force the public input to be the hash of the secret key so that we can ascertain decryption can happen + let sk_hash = self.verify_sk_hash( + 
layouter.namespace(|| "verify_sk_hash"), + &self.config, + &sk_var, + )?; + + layouter + .constrain_instance( + c1.x().native().cell(), + self.config.instance, + C1_X + row_offset, + ) + .and(layouter.constrain_instance( + c1.y().native().cell(), + self.config.instance, + C1_Y + row_offset, + )) + .and(layouter.constrain_instance( + sk_hash.cell(), + self.config.instance, + SK_H + row_offset, + ))?; + + let c2: Result>, _> = msg_var + .iter() + .map(|m| { + self.verify_encryption( + layouter.namespace(|| "verify_encryption"), + &self.config, + m, + &s, + ) + }) + .collect(); + + let c2 = c2?; + + let c2_hash = self.verify_encrypted_msg_hash( + layouter.namespace(|| "verify_c2_hash"), + &self.config, + &c2, + )?; + + layouter.constrain_instance(c2_hash.cell(), self.config.instance, C2_H + row_offset)?; + + let mut assigned_input: Tensor> = + msg_var.iter().map(|e| ValType::from(e.clone())).into(); + + assigned_input.reshape(inputs[0].dims()).map_err(|e| { + log::error!("reshape failed: {:?}", e); + Error::Synthesis + })?; + + log::trace!( + "layout (N={:?}) took: {:?}", + msg_var.len(), + start_time.elapsed() + ); + + Ok(assigned_input.into()) + } + + fn num_rows(input_len: usize) -> usize { + // this was determined by running the circuit and looking at the number of constraints + // in the test called hash_for_a_range_of_input_sizes, then regressing in python to find the slope + // ```python + // import numpy as np + // x = [1, 2, 3, 512, 513, 514] + // y = [75424, 75592, 75840, 161017, 161913, 162000] + // def fit_above(x, y) : + // x0, y0 = x[0] - 1, y[0] + // x -= x0 + // y -= y0 + // def error_function_2(b, x, y) : + // a = np.min((y - b) / x) + // return np.sum((y - a * x - b)**2) + // b = scipy.optimize.minimize(error_function_2, [0], args=(x, y)).x[0] + // a = np.max((y - b) / x) + // return a, b - a * x0 + y0 + // a, b = fit_above(x, y) + // plt.plot(x, y, 'o') + // plt.plot(x, a*x + b, '-') + // plt.show() + // for (x_i, y_i) in zip(x,y): + // assert y_i 
<= a*x_i + b + // print(a, b) + // ``` + const NUM_CONSTRAINTS_SLOPE: usize = 196; + const NUM_CONSTRAINTS_INTERCEPT: usize = 75257; + + // check if even or odd + input_len * NUM_CONSTRAINTS_SLOPE + NUM_CONSTRAINTS_INTERCEPT + } +} + +#[cfg(test)] +mod tests { + use crate::circuit::modules::ModulePlanner; + + use super::*; + use ark_std::test_rng; + use halo2_proofs::{dev::MockProver, plonk::Circuit}; + + struct EncryptionCircuit { + message: ValTensor, + variables: ElGamalVariables, + } + + impl Circuit for EncryptionCircuit { + type Config = ElGamalConfig; + type FloorPlanner = ModulePlanner; + type Params = (); + + fn without_witnesses(&self) -> Self { + let empty_val: Vec> = vec![Value::::unknown().into()]; + let message: Tensor> = empty_val.into_iter().into(); + + let variables = ElGamalVariables::default(); + + Self { + message: message.into(), + variables, + } + } + + fn configure(meta: &mut ConstraintSystem) -> ElGamalConfig { + ElGamalGadget::configure(meta, ()) + } + + fn synthesize( + &self, + config: ElGamalConfig, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + let mut chip = ElGamalGadget::new(config); + chip.load_variables(self.variables.clone()); + let sk: Tensor> = + Tensor::new(Some(&[Value::known(self.variables.sk).into()]), &[1]).unwrap(); + chip.layout(&mut layouter, &[self.message.clone(), sk.into()], 0)?; + Ok(()) + } + } + + #[test] + // this is for backwards compatibility with the old format + fn test_variables_serialization_round_trip() { + let mut rng = test_rng(); + + let var = ElGamalVariables::gen_random(&mut rng); + + let mut buf = vec![]; + serde_json::to_writer(&mut buf, &var).unwrap(); + + let var2 = serde_json::from_reader(&buf[..]).unwrap(); + + assert_eq!(var, var2); + } + + #[test] + pub fn test_encrypt_decrypt() { + let mut rng = test_rng(); + + let var = ElGamalVariables::gen_random(&mut rng); + + let mut msg = vec![]; + // + for _ in 0..32 { + msg.push(Fr::random(&mut rng)); + } + + let cipher = 
ElGamalGadget::encrypt(var.pk, msg.clone(), var.r); + + let decrypted_msg = ElGamalGadget::decrypt(&cipher, var.sk); + + assert_eq!(decrypted_msg, msg); + } + + #[test] + pub fn test_circuit() { + let mut rng = test_rng(); + + let var = ElGamalVariables::gen_random(&mut rng); + + let mut msg = vec![]; + // + for _ in 0..2 { + msg.push(Fr::random(&mut rng)); + } + + let run_inputs = (msg.clone(), var.clone()); + let public_inputs: Vec> = ElGamalGadget::run(run_inputs).unwrap(); + + let message: Tensor> = msg.into_iter().map(|m| Value::known(m).into()).into(); + + let circuit = EncryptionCircuit { + message: message.into(), + variables: var, + }; + + let res = MockProver::run(17, &circuit, public_inputs).unwrap(); + res.assert_satisfied_par(); + } + + #[test] + #[ignore] + pub fn test_circuit_range_of_input_sizes() { + let mut rng = test_rng(); + + #[cfg(not(target_arch = "wasm32"))] + env_logger::init(); + + // + for i in [1, 2, 3, 512, 513, 514, 1024] { + println!("i is {} ----------------------------------------", i); + + let var = ElGamalVariables::gen_random(&mut rng); + let mut msg = vec![]; + for _ in 0..i { + msg.push(Fr::random(&mut rng)); + } + + let run_inputs = (msg.clone(), var.clone()); + let public_inputs: Vec> = ElGamalGadget::run(run_inputs).unwrap(); + + let message: Tensor> = + msg.into_iter().map(|m| Value::known(m).into()).into(); + + let circuit = EncryptionCircuit { + message: message.into(), + variables: var, + }; + + let res = MockProver::run(19, &circuit, public_inputs).unwrap(); + res.assert_satisfied_par(); + } + } +} diff --git a/mnist_ezkl/src/circuit/modules/elgamal/add_chip.rs b/mnist_ezkl/src/circuit/modules/elgamal/add_chip.rs new file mode 100644 index 0000000..ef777cb --- /dev/null +++ b/mnist_ezkl/src/circuit/modules/elgamal/add_chip.rs @@ -0,0 +1,92 @@ +use halo2_proofs::{ + circuit::{AssignedCell, Chip, Layouter}, + plonk::{self, Advice, Column, ConstraintSystem, Constraints, Selector}, + poly::Rotation, +}; +use 
halo2curves::bn256::Fr;
use halo2curves::ff::PrimeField;

/// An instruction set for adding two circuit words (field elements).
// NOTE(review): generic parameters were stripped in transit (`Chip`, `Layouter`,
// `AssignedCell`, `Column` lost their `<…>` arguments); restored to the halo2
// signatures, with the field bound evidenced by the `PrimeField` import above.
pub(crate) trait AddInstruction<F: PrimeField>: Chip<F> {
    /// Constrains `a + b` and returns the sum.
    fn add(
        &self,
        layouter: impl Layouter<F>,
        a: &AssignedCell<F, F>,
        b: &AssignedCell<F, F>,
    ) -> Result<AssignedCell<F, F>, plonk::Error>;
}

/// Columns and selector used by the single-row addition gate.
#[derive(Clone, Debug)]
pub struct AddConfig {
    a: Column<Advice>,
    b: Column<Advice>,
    c: Column<Advice>,
    q_add: Selector,
}

/// A chip implementing a single addition constraint `c = a + b` on a single row.
#[derive(Clone, Debug)]
pub struct AddChip {
    config: AddConfig,
}

impl Chip<Fr> for AddChip {
    type Config = AddConfig;
    type Loaded = ();

    fn config(&self) -> &Self::Config {
        &self.config
    }

    fn loaded(&self) -> &Self::Loaded {
        &()
    }
}

impl AddChip {
    /// Registers the `c = a + b` gate over the three advice columns and
    /// returns the resulting configuration.
    pub fn configure(
        meta: &mut ConstraintSystem<Fr>,
        a: Column<Advice>,
        b: Column<Advice>,
        c: Column<Advice>,
    ) -> AddConfig {
        let q_add = meta.selector();
        meta.create_gate("Field element addition: c = a + b", |meta| {
            let q_add = meta.query_selector(q_add);
            let a = meta.query_advice(a, Rotation::cur());
            let b = meta.query_advice(b, Rotation::cur());
            let c = meta.query_advice(c, Rotation::cur());

            // Enforce a + b - c == 0 whenever q_add is enabled.
            Constraints::with_selector(q_add, Some(a + b - c))
        });

        AddConfig { a, b, c, q_add }
    }

    /// Wrap an already-built `AddConfig` in a chip instance.
    pub fn construct(config: AddConfig) -> Self {
        Self { config }
    }
}

impl AddInstruction<Fr> for AddChip {
    fn add(
        &self,
        mut layouter: impl Layouter<Fr>,
        a: &AssignedCell<Fr, Fr>,
        b: &AssignedCell<Fr, Fr>,
    ) -> Result<AssignedCell<Fr, Fr>, plonk::Error> {
        layouter.assign_region(
            || "c = a + b",
            |mut region| {
                self.config.q_add.enable(&mut region, 0)?;

                // Copy-constrain both operands into row 0 of this region.
                a.copy_advice(|| "copy a", &mut region, self.config.a, 0)?;
                b.copy_advice(|| "copy b", &mut region, self.config.b, 0)?;

                // Witness the sum into column c; the gate enforces correctness.
                let scalar_val = a.value().zip(b.value()).map(|(a, b)| a + b);
                region.assign_advice(|| "c", self.config.c, 0, || scalar_val)
            },
        )
    }
}
diff --git a/mnist_ezkl/src/circuit/modules/kzg.rs
b/mnist_ezkl/src/circuit/modules/kzg.rs new file mode 100644 index 0000000..fb6157c --- /dev/null +++ b/mnist_ezkl/src/circuit/modules/kzg.rs @@ -0,0 +1,245 @@ +/* +An easy-to-use implementation of the Poseidon Hash in the form of a Halo2 Chip. While the Poseidon Hash function +is already implemented in halo2_gadgets, there is no wrapper chip that makes it easy to use in other circuits. +Thanks to https://github.com/summa-dev/summa-solvency/blob/master/src/chips/poseidon/hash.rs for the inspiration (and also helping us understand how to use this). +*/ + +// This chip adds a set of advice columns to the gadget Chip to store the inputs of the hash +use halo2_proofs::halo2curves::bn256::Fr as Fp; +use halo2_proofs::poly::commitment::{Blind, Params}; +use halo2_proofs::poly::kzg::commitment::ParamsKZG; +use halo2_proofs::{circuit::*, plonk::*}; +use halo2curves::bn256::{Bn256, G1Affine}; +use halo2curves::group::prime::PrimeCurveAffine; +use halo2curves::group::Curve; +use halo2curves::CurveAffine; + +use crate::tensor::{Tensor, ValTensor, ValType, VarTensor}; + +use super::Module; + +/// The number of instance columns used by the KZG hash function +pub const NUM_INSTANCE_COLUMNS: usize = 0; +/// The number of advice columns used by the KZG hash function +pub const NUM_INNER_COLS: usize = 1; + +#[derive(Debug, Clone)] +/// WIDTH, RATE and L are const generics for the struct, which represent the width, rate, and number of inputs for the Poseidon hash function, respectively. +/// This means they are values that are known at compile time and can be used to specialize the implementation of the struct. +/// The actual chip provided by halo2_gadgets is added to the parent Chip. 
+pub struct KZGConfig { + /// + pub hash_inputs: VarTensor, +} + +type InputAssignments = (); + +/// PoseidonChip is a wrapper around the Pow5Chip that adds a set of advice columns to the gadget Chip to store the inputs of the hash +#[derive(Debug)] +pub struct KZGChip { + config: KZGConfig, +} + +impl KZGChip { + /// Returns the number of inputs to the hash function + pub fn commit( + message: Vec, + degree: u32, + num_unusable_rows: u32, + params: &ParamsKZG, + ) -> Vec { + let k = params.k(); + let domain = halo2_proofs::poly::EvaluationDomain::new(degree, k); + let n = 2_u64.pow(k) - num_unusable_rows as u64; + let num_poly = (message.len() / n as usize) + 1; + let mut poly = vec![domain.empty_lagrange(); num_poly]; + + (0..num_unusable_rows).for_each(|i| { + for p in &mut poly { + p[(n + i as u64) as usize] = Blind::default().0; + } + }); + + for (i, m) in message.iter().enumerate() { + let x = i / (n as usize); + let y = i % (n as usize); + poly[x][y] = *m; + } + + let mut advice_commitments_projective = vec![]; + for a in poly { + advice_commitments_projective.push(params.commit_lagrange(&a, Blind::default())) + } + + let mut advice_commitments = + vec![G1Affine::identity(); advice_commitments_projective.len()]; + ::CurveExt::batch_normalize( + &advice_commitments_projective, + &mut advice_commitments, + ); + advice_commitments + } +} + +impl Module for KZGChip { + type Config = KZGConfig; + type InputAssignments = InputAssignments; + type RunInputs = Vec; + type Params = (usize, usize); + + fn name(&self) -> &'static str { + "KZG" + } + + fn instance_increment_input(&self) -> Vec { + vec![0] + } + + /// Constructs a new PoseidonChip + fn new(config: Self::Config) -> Self { + Self { config } + } + + /// Configuration of the PoseidonChip + fn configure(meta: &mut ConstraintSystem, params: Self::Params) -> Self::Config { + let hash_inputs = VarTensor::new_unblinded_advice(meta, params.0, NUM_INNER_COLS, params.1); + Self::Config { hash_inputs } + } + + fn 
layout_inputs( + &self, + _: &mut impl Layouter, + _: &[ValTensor], + ) -> Result { + Ok(()) + } + + /// L is the number of inputs to the hash function + /// Takes the cells containing the input values of the hash function and return the cell containing the hash output + /// It uses the pow5_chip to compute the hash + fn layout( + &self, + layouter: &mut impl Layouter, + input: &[ValTensor], + _: usize, + ) -> Result, Error> { + assert_eq!(input.len(), 1); + layouter.assign_region( + || "kzg commit", + |mut region| self.config.hash_inputs.assign(&mut region, 0, &input[0]), + ) + } + + /// + fn run(message: Vec) -> Result>, Box> { + Ok(vec![message]) + } + + fn num_rows(_: usize) -> usize { + 0 + } +} + +#[allow(unused)] +mod tests { + + use crate::circuit::modules::ModulePlanner; + + use super::*; + + use std::marker::PhantomData; + + use halo2_proofs::{ + circuit::{Layouter, SimpleFloorPlanner, Value}, + plonk::{Circuit, ConstraintSystem}, + }; + use halo2curves::ff::Field; + + const K: usize = 8; + const R: usize = 2048; + + struct HashCircuit { + message: ValTensor, + } + + impl Circuit for HashCircuit { + type Config = KZGConfig; + type FloorPlanner = ModulePlanner; + type Params = (); + + fn without_witnesses(&self) -> Self { + let empty_val: Vec> = vec![Value::::unknown().into(); R]; + let message: Tensor> = empty_val.into_iter().into(); + + Self { + message: message.into(), + } + } + + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + let params = (K, R); + KZGChip::configure(meta, params) + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + let kzg_chip = KZGChip::new(config); + kzg_chip.layout(&mut layouter, &[self.message.clone()], 0); + + Ok(()) + } + } + + #[test] + #[ignore] + fn kzg_for_a_range_of_input_sizes() { + let rng = rand::rngs::OsRng; + + #[cfg(not(target_arch = "wasm32"))] + env_logger::init(); + + { + let i = 32; + // print a bunch of new lines + println!( + "i is {} 
-------------------------------------------------", + i + ); + + let message: Vec = (0..i).map(|_| Fp::random(rng)).collect::>(); + + let mut message: Tensor> = + message.into_iter().map(|m| Value::known(m).into()).into(); + + let circuit = HashCircuit { + message: message.into(), + }; + let prover = halo2_proofs::dev::MockProver::run(K as u32, &circuit, vec![]).unwrap(); + + assert_eq!(prover.verify_par(), Ok(())) + } + } + + #[test] + #[ignore] + fn kzg_commit_much_longer_input() { + #[cfg(not(target_arch = "wasm32"))] + env_logger::init(); + + let rng = rand::rngs::OsRng; + + let mut message: Vec = (0..2048).map(|_| Fp::random(rng)).collect::>(); + + let mut message: Tensor> = + message.into_iter().map(|m| Value::known(m).into()).into(); + + let circuit = HashCircuit { + message: message.into(), + }; + let prover = halo2_proofs::dev::MockProver::run(K as u32, &circuit, vec![]).unwrap(); + assert_eq!(prover.verify_par(), Ok(())) + } +} diff --git a/mnist_ezkl/src/circuit/modules/mod.rs b/mnist_ezkl/src/circuit/modules/mod.rs new file mode 100644 index 0000000..9b31df8 --- /dev/null +++ b/mnist_ezkl/src/circuit/modules/mod.rs @@ -0,0 +1,57 @@ +/// +pub mod poseidon; + +/// +pub mod elgamal; + +/// +pub mod kzg; + +/// +pub mod planner; +use halo2_proofs::{ + circuit::Layouter, + plonk::{ConstraintSystem, Error}, +}; +use halo2curves::ff::PrimeField; +pub use planner::*; + +use crate::tensor::{TensorType, ValTensor}; + +/// Module trait used to extend ezkl functionality +pub trait Module { + /// Config + type Config; + /// The return type after an input assignment + type InputAssignments; + /// The inputs used in the run function + type RunInputs; + /// The params used in configure + type Params; + + /// construct new module from config + fn new(config: Self::Config) -> Self; + /// Configure + fn configure(meta: &mut ConstraintSystem, params: Self::Params) -> Self::Config; + /// Name + fn name(&self) -> &'static str; + /// Run the operation the module represents + 
fn run(input: Self::RunInputs) -> Result>, Box>; + /// Layout inputs + fn layout_inputs( + &self, + layouter: &mut impl Layouter, + input: &[ValTensor], + ) -> Result; + /// Layout + fn layout( + &self, + layouter: &mut impl Layouter, + input: &[ValTensor], + row_offset: usize, + ) -> Result, Error>; + /// Number of instance values the module uses every time it is applied + fn instance_increment_input(&self) -> Vec; + /// Number of rows used by the module + fn num_rows(input_len: usize) -> usize; +} diff --git a/mnist_ezkl/src/circuit/modules/planner.rs b/mnist_ezkl/src/circuit/modules/planner.rs new file mode 100644 index 0000000..2b02914 --- /dev/null +++ b/mnist_ezkl/src/circuit/modules/planner.rs @@ -0,0 +1,545 @@ +use std::cmp; +use std::collections::HashMap; +use std::fmt; +use std::marker::PhantomData; + +use halo2curves::ff::Field; + +use halo2_proofs::{ + circuit::{ + layouter::{RegionColumn, RegionLayouter, RegionShape, SyncDeps, TableLayouter}, + Cell, Layouter, Region, RegionIndex, RegionStart, Table, Value, + }, + plonk::{ + Advice, Any, Assigned, Assignment, Challenge, Circuit, Column, Error, Fixed, FloorPlanner, + Instance, Selector, TableColumn, + }, +}; +use log::{trace, warn}; + +/// A simple [`FloorPlanner`] that performs minimal optimizations. +#[derive(Debug)] +pub struct ModulePlanner; + +impl FloorPlanner for ModulePlanner { + fn synthesize + SyncDeps, C: Circuit>( + cs: &mut CS, + circuit: &C, + config: C::Config, + constants: Vec>, + ) -> Result<(), Error> { + let layouter = ModuleLayouter::new(cs, constants)?; + circuit.synthesize(config, layouter) + } +} +/// +pub type ModuleIdx = usize; +/// +pub type RegionIdx = usize; + +/// A [`Layouter`] for a circuit with multiple modules. +pub struct ModuleLayouter<'a, F: Field, CS: Assignment + 'a> { + cs: &'a mut CS, + constants: Vec>, + /// Stores the starting row for each region. + regions: HashMap>, + /// Stores the starting row for each region. 
+ region_idx: HashMap, + /// Stores the first empty row for each column. + columns: HashMap<(ModuleIdx, RegionColumn), usize>, + /// Stores the table fixed columns. + table_columns: Vec, + _marker: PhantomData, + /// current module + current_module: usize, + /// num_constants + total_constants: usize, +} + +impl<'a, F: Field, CS: Assignment + 'a> fmt::Debug for ModuleLayouter<'a, F, CS> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ModuleLayouter") + .field("regions", &self.regions) + .field("columns", &self.columns) + .field("total_constants", &self.total_constants) + .finish() + } +} + +impl<'a, F: Field, CS: Assignment> ModuleLayouter<'a, F, CS> { + /// Creates a new module layouter. + pub fn new(cs: &'a mut CS, constants: Vec>) -> Result { + let ret = ModuleLayouter { + cs, + constants, + regions: HashMap::default(), + columns: HashMap::default(), + region_idx: HashMap::default(), + table_columns: vec![], + current_module: 0, + total_constants: 0, + _marker: PhantomData, + }; + Ok(ret) + } + + /// + fn get_constant_col_cartesian_coord( + &self, + linear_coord: usize, + col_size: usize, + ) -> (usize, usize) { + let x = linear_coord / col_size; + let y = linear_coord % col_size; + (x, y) + } +} + +impl<'a, F: Field, CS: Assignment + 'a + SyncDeps> Layouter for ModuleLayouter<'a, F, CS> { + type Root = Self; + + fn assign_region(&mut self, name: N, mut assignment: A) -> Result + where + A: FnMut(Region<'_, F>) -> Result, + N: Fn() -> NR, + NR: Into, + { + // if the name contains the required substring we increment the current module idx + if Into::::into(name()).contains("_enter_module_") { + let name = Into::::into(name()); + let index = match name.split("_enter_module_").last() { + Some(v) => v, + None => { + log::error!("Invalid module name"); + return Err(Error::Synthesis); + } + }; + let index = index.parse::().map_err(|_| { + log::error!("Invalid module name"); + return Error::Synthesis; + })?; + if 
!self.regions.contains_key(&index) { + warn!("spawning module {}", index) + }; + self.current_module = index; + } + + let region_index = self.region_idx.len(); + self.region_idx.insert(region_index, self.current_module); + + // Get shape of the region. + let mut shape = RegionShape::new(region_index.into()); + { + let region: &mut dyn RegionLayouter = &mut shape; + assignment(region.into())?; + } + + // Modules are stacked horizontally across new columns -- THIS ASSUMES THE MODULES HAVE NON OVERLAPPING COLUMNS. + let region_start = match self.regions.get_mut(&self.current_module) { + Some(v) => { + let mut region_start = 0; + for column in shape.columns().iter() { + region_start = cmp::max( + region_start, + self.columns + .get(&(self.current_module, *column)) + .cloned() + .unwrap_or(0), + ); + } + + v.insert(region_index, region_start.into()); + region_start + } + None => { + let map = HashMap::from([(region_index, 0.into())]); + self.regions.insert(self.current_module, map); + 0 + } + }; + + // Update column usage information. + for column in shape.columns() { + self.columns.insert( + (self.current_module, *column), + region_start + shape.row_count(), + ); + } + + // Assign region cells. + self.cs.enter_region(name); + let mut region = ModuleLayouterRegion::new(self, region_index.into()); + let result = { + let region: &mut dyn RegionLayouter = &mut region; + assignment(region.into()) + }?; + let constants_to_assign = region.constants; + self.cs.exit_region(); + + // Assign constants. For the simple floor planner, we assign constants in order in + // the first `constants` column. 
+ if self.constants.is_empty() { + if !constants_to_assign.is_empty() { + return Err(Error::NotEnoughColumnsForConstants); + } + } else { + for (constant, advice) in constants_to_assign { + // read path from OS env + let (constant_column, y) = crate::graph::GLOBAL_SETTINGS.with(|settings| { + match settings.borrow().as_ref() { + Some(settings) => { + let col_size = settings.available_col_size(); + let (x, y) = self + .get_constant_col_cartesian_coord(self.total_constants, col_size); + (self.constants[x], y) + } + None => (self.constants[0], self.total_constants), + } + }); + + self.cs.assign_fixed( + || format!("Constant({:?})", constant.evaluate()), + constant_column, + y, + || Value::known(constant), + )?; + + let region_module = self.region_idx[&advice.region_index]; + + self.cs.copy( + constant_column.into(), + y, + advice.column, + *self.regions[&region_module][&advice.region_index] + advice.row_offset, + )?; + self.total_constants += 1; + } + } + + trace!("region {} assigned", region_index); + trace!("total_constants: {:?}", self.total_constants); + let max_row_index = self + .columns + .iter() + .filter(|((module, _), _)| *module == self.current_module) + .map(|(_, row)| *row) + .max() + .unwrap_or(0); + trace!("max_row_index: {:?}", max_row_index); + + Ok(result) + } + + fn assign_table(&mut self, name: N, mut assignment: A) -> Result<(), Error> + where + A: FnMut(Table<'_, F>) -> Result<(), Error>, + N: Fn() -> NR, + NR: Into, + { + // Maintenance hazard: there is near-duplicate code in `v1::AssignmentPass::assign_table`. + // Assign table cells. 
+ self.cs.enter_region(name); + let mut table = + halo2_proofs::circuit::SimpleTableLayouter::new(self.cs, &self.table_columns); + { + let table: &mut dyn TableLayouter = &mut table; + assignment(table.into()) + }?; + let default_and_assigned = table.default_and_assigned; + self.cs.exit_region(); + + // Check that all table columns have the same length `first_unused`, + // and all cells up to that length are assigned. + let first_unused = { + match default_and_assigned + .values() + .map(|(_, assigned)| { + if assigned.iter().all(|b| *b) { + Some(assigned.len()) + } else { + None + } + }) + .reduce(|acc, item| match (acc, item) { + (Some(a), Some(b)) if a == b => Some(a), + _ => None, + }) { + Some(Some(len)) => len, + _ => return Err(Error::Synthesis), // TODO better error + } + }; + + // Record these columns so that we can prevent them from being used again. + for column in default_and_assigned.keys() { + self.table_columns.push(*column); + } + + for (col, (default_val, _)) in default_and_assigned { + // default_val must be Some because we must have assigned + // at least one cell in each column, and in that case we checked + // that all cells up to first_unused were assigned. 
+ let default_val = default_val.ok_or(Error::Synthesis)?; + + self.cs + .fill_from_row(col.inner(), first_unused, default_val)?; + } + + Ok(()) + } + + fn constrain_instance( + &mut self, + cell: Cell, + instance: Column, + row: usize, + ) -> Result<(), Error> { + let module_idx = self.region_idx[&cell.region_index]; + + self.cs.copy( + cell.column, + *self.regions[&module_idx][&cell.region_index] + cell.row_offset, + instance.into(), + row, + ) + } + + fn get_challenge(&self, challenge: Challenge) -> Value { + self.cs.get_challenge(challenge) + } + + fn get_root(&mut self) -> &mut Self::Root { + self + } + + fn push_namespace(&mut self, name_fn: N) + where + NR: Into, + N: FnOnce() -> NR, + { + self.cs.push_namespace(name_fn) + } + + fn pop_namespace(&mut self, gadget_name: Option) { + self.cs.pop_namespace(gadget_name) + } +} + +struct ModuleLayouterRegion<'r, 'a, F: Field, CS: Assignment + 'a> { + layouter: &'r mut ModuleLayouter<'a, F, CS>, + region_index: RegionIndex, + /// Stores the constants to be assigned, and the cells to which they are copied. 
+ constants: Vec<(Assigned, Cell)>, +} + +impl<'r, 'a, F: Field, CS: Assignment + 'a> fmt::Debug for ModuleLayouterRegion<'r, 'a, F, CS> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ModuleLayouterRegion") + .field("layouter", &self.layouter) + .field("region_index", &self.region_index) + .finish() + } +} + +impl<'r, 'a, F: Field, CS: Assignment + 'a> ModuleLayouterRegion<'r, 'a, F, CS> { + fn new(layouter: &'r mut ModuleLayouter<'a, F, CS>, region_index: RegionIndex) -> Self { + ModuleLayouterRegion { + layouter, + region_index, + constants: vec![], + } + } +} + +impl<'r, 'a, F: Field, CS: Assignment + 'a + SyncDeps> RegionLayouter + for ModuleLayouterRegion<'r, 'a, F, CS> +{ + fn instance_value( + &mut self, + instance: Column, + row: usize, + ) -> Result, Error> { + self.layouter.cs.query_instance(instance, row) + } + + fn enable_selector<'v>( + &'v mut self, + annotation: &'v (dyn Fn() -> String + 'v), + selector: &Selector, + offset: usize, + ) -> Result<(), Error> { + let module_idx = self.layouter.region_idx[&self.region_index]; + self.layouter.cs.enable_selector( + annotation, + selector, + *self.layouter.regions[&module_idx][&self.region_index] + offset, + ) + } + + fn name_column<'v>( + &'v mut self, + annotation: &'v (dyn Fn() -> String + 'v), + column: Column, + ) { + self.layouter.cs.annotate_column(annotation, column); + } + + fn assign_advice<'v>( + &'v mut self, + annotation: &'v (dyn Fn() -> String + 'v), + column: Column, + offset: usize, + to: &'v mut (dyn FnMut() -> Value> + 'v), + ) -> Result { + let module_idx = self.layouter.region_idx[&self.region_index]; + + self.layouter.cs.assign_advice( + annotation, + column, + *self.layouter.regions[&module_idx][&self.region_index] + offset, + to, + )?; + + Ok(Cell { + region_index: self.region_index, + row_offset: offset, + column: column.into(), + }) + } + + fn assign_advice_from_constant<'v>( + &'v mut self, + annotation: &'v (dyn Fn() -> String + 'v), + column: 
Column, + offset: usize, + constant: Assigned, + ) -> Result { + let advice = + self.assign_advice(annotation, column, offset, &mut || Value::known(constant))?; + self.constrain_constant(advice, constant)?; + + Ok(advice) + } + + fn assign_advice_from_instance<'v>( + &mut self, + annotation: &'v (dyn Fn() -> String + 'v), + instance: Column, + row: usize, + advice: Column, + offset: usize, + ) -> Result<(Cell, Value), Error> { + let value = self.layouter.cs.query_instance(instance, row)?; + + let cell = self.assign_advice(annotation, advice, offset, &mut || value.to_field())?; + let module_idx = self.layouter.region_idx[&cell.region_index]; + + self.layouter.cs.copy( + cell.column, + *self.layouter.regions[&module_idx][&cell.region_index] + cell.row_offset, + instance.into(), + row, + )?; + + Ok((cell, value)) + } + + fn assign_fixed<'v>( + &'v mut self, + annotation: &'v (dyn Fn() -> String + 'v), + column: Column, + offset: usize, + to: &'v mut (dyn FnMut() -> Value> + 'v), + ) -> Result { + let module_idx = self.layouter.region_idx[&self.region_index]; + + self.layouter.cs.assign_fixed( + annotation, + column, + *self.layouter.regions[&module_idx][&self.region_index] + offset, + to, + )?; + + Ok(Cell { + region_index: self.region_index, + row_offset: offset, + column: column.into(), + }) + } + + fn constrain_constant(&mut self, cell: Cell, constant: Assigned) -> Result<(), Error> { + self.constants.push((constant, cell)); + Ok(()) + } + + fn constrain_equal(&mut self, left: Cell, right: Cell) -> Result<(), Error> { + let left_module = self.layouter.region_idx[&left.region_index]; + let right_module = self.layouter.region_idx[&right.region_index]; + + self.layouter.cs.copy( + left.column, + *self.layouter.regions[&left_module][&left.region_index] + left.row_offset, + right.column, + *self.layouter.regions[&right_module][&right.region_index] + right.row_offset, + )?; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use halo2curves::pasta::vesta; + + use 
super::ModulePlanner; + use halo2_proofs::{ + dev::MockProver, + plonk::{Advice, Circuit, Column, Error}, + }; + + #[test] + fn not_enough_columns_for_constants() { + struct MyCircuit {} + + impl Circuit for MyCircuit { + type Config = Column; + type FloorPlanner = ModulePlanner; + type Params = (); + + fn without_witnesses(&self) -> Self { + MyCircuit {} + } + + fn configure( + meta: &mut halo2_proofs::plonk::ConstraintSystem, + ) -> Self::Config { + meta.advice_column() + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl halo2_proofs::circuit::Layouter, + ) -> Result<(), halo2_proofs::plonk::Error> { + layouter.assign_region( + || "assign constant", + |mut region| { + region.assign_advice_from_constant( + || "one", + config, + 0, + vesta::Scalar::one(), + ) + }, + )?; + + Ok(()) + } + } + + let circuit = MyCircuit {}; + assert!(matches!( + MockProver::run(3, &circuit, vec![]).unwrap_err(), + Error::NotEnoughColumnsForConstants, + )); + } +} diff --git a/mnist_ezkl/src/circuit/modules/poseidon.rs b/mnist_ezkl/src/circuit/modules/poseidon.rs new file mode 100644 index 0000000..ca7caba --- /dev/null +++ b/mnist_ezkl/src/circuit/modules/poseidon.rs @@ -0,0 +1,578 @@ +/* +An easy-to-use implementation of the Poseidon Hash in the form of a Halo2 Chip. While the Poseidon Hash function +is already implemented in halo2_gadgets, there is no wrapper chip that makes it easy to use in other circuits. +Thanks to https://github.com/summa-dev/summa-solvency/blob/master/src/chips/poseidon/hash.rs for the inspiration (and also helping us understand how to use this). 
+*/ + +pub mod poseidon_params; +pub mod spec; + +// This chip adds a set of advice columns to the gadget Chip to store the inputs of the hash +use halo2_gadgets::poseidon::{primitives::*, Hash, Pow5Chip, Pow5Config}; +use halo2_proofs::arithmetic::Field; +use halo2_proofs::halo2curves::bn256::Fr as Fp; +use halo2_proofs::{circuit::*, plonk::*}; +// use rayon::prelude::{IndexedParallelIterator, IntoParallelRefIterator}; +use rayon::prelude::ParallelIterator; +use rayon::slice::ParallelSlice; + +use std::marker::PhantomData; + +use crate::tensor::{Tensor, ValTensor, ValType}; + +use super::Module; + +/// The number of instance columns used by the Poseidon hash function +pub const NUM_INSTANCE_COLUMNS: usize = 1; + +#[derive(Debug, Clone)] +/// WIDTH, RATE and L are const generics for the struct, which represent the width, rate, and number of inputs for the Poseidon hash function, respectively. +/// This means they are values that are known at compile time and can be used to specialize the implementation of the struct. +/// The actual chip provided by halo2_gadgets is added to the parent Chip. 
+pub struct PoseidonConfig { + /// + pub hash_inputs: Vec>, + /// + pub instance: Option>, + /// + pub pow5_config: Pow5Config, +} + +type InputAssignments = (Vec>, AssignedCell); + +/// PoseidonChip is a wrapper around the Pow5Chip that adds a set of advice columns to the gadget Chip to store the inputs of the hash +#[derive(Debug, Clone)] +pub struct PoseidonChip< + S: Spec + Sync, + const WIDTH: usize, + const RATE: usize, + const L: usize, +> { + config: PoseidonConfig, + _marker: PhantomData, +} + +impl + Sync, const WIDTH: usize, const RATE: usize, const L: usize> + PoseidonChip +{ + /// Creates a new PoseidonChip + pub fn configure_with_cols( + meta: &mut ConstraintSystem, + partial_sbox: Column, + rc_a: [Column; WIDTH], + rc_b: [Column; WIDTH], + hash_inputs: Vec>, + instance: Option>, + ) -> PoseidonConfig { + let pow5_config = Pow5Chip::configure::( + meta, + hash_inputs.clone().try_into().unwrap(), + partial_sbox, + rc_a, + rc_b, + ); + + PoseidonConfig { + pow5_config, + instance, + hash_inputs, + } + } +} + +impl + Sync, const WIDTH: usize, const RATE: usize, const L: usize> + PoseidonChip +{ + /// Configuration of the PoseidonChip + pub fn configure_with_optional_instance( + meta: &mut ConstraintSystem, + instance: Option>, + ) -> PoseidonConfig { + // instantiate the required columns + let hash_inputs = (0..WIDTH).map(|_| meta.advice_column()).collect::>(); + for input in &hash_inputs { + meta.enable_equality(*input); + } + + let partial_sbox = meta.advice_column(); + let rc_a = (0..WIDTH).map(|_| meta.fixed_column()).collect::>(); + let rc_b = (0..WIDTH).map(|_| meta.fixed_column()).collect::>(); + + for input in hash_inputs.iter().take(WIDTH) { + meta.enable_equality(*input); + } + meta.enable_constant(rc_b[0]); + + Self::configure_with_cols( + meta, + partial_sbox, + rc_a.try_into().unwrap(), + rc_b.try_into().unwrap(), + hash_inputs, + instance, + ) + } +} + +impl + Sync, const WIDTH: usize, const RATE: usize, const L: usize> + Module for 
PoseidonChip +{ + type Config = PoseidonConfig; + type InputAssignments = InputAssignments; + type RunInputs = Vec; + type Params = (); + + fn name(&self) -> &'static str { + "Poseidon" + } + + fn instance_increment_input(&self) -> Vec { + vec![1] + } + + /// Constructs a new PoseidonChip + fn new(config: Self::Config) -> Self { + Self { + config, + _marker: PhantomData, + } + } + + /// Configuration of the PoseidonChip + fn configure(meta: &mut ConstraintSystem, _: Self::Params) -> Self::Config { + // instantiate the required columns + let hash_inputs = (0..WIDTH).map(|_| meta.advice_column()).collect::>(); + for input in &hash_inputs { + meta.enable_equality(*input); + } + + let partial_sbox = meta.advice_column(); + let rc_a = (0..WIDTH).map(|_| meta.fixed_column()).collect::>(); + let rc_b = (0..WIDTH).map(|_| meta.fixed_column()).collect::>(); + + for input in hash_inputs.iter().take(WIDTH) { + meta.enable_equality(*input); + } + meta.enable_constant(rc_b[0]); + + let instance = meta.instance_column(); + meta.enable_equality(instance); + + Self::configure_with_cols( + meta, + partial_sbox, + rc_a.try_into().unwrap(), + rc_b.try_into().unwrap(), + hash_inputs, + Some(instance), + ) + } + + fn layout_inputs( + &self, + layouter: &mut impl Layouter, + message: &[ValTensor], + ) -> Result { + assert_eq!(message.len(), 1); + let message = message[0].clone(); + + let start_time = instant::Instant::now(); + + let res = layouter.assign_region( + || "load message", + |mut region| { + let assigned_message: Result>, Error> = match &message { + ValTensor::Value { inner: v, .. } => v + .iter() + .enumerate() + .map(|(i, value)| { + let x = i % WIDTH; + let y = i / WIDTH; + + match value { + ValType::Value(v) => region.assign_advice( + || format!("load message_{}", i), + self.config.hash_inputs[x], + y, + || *v, + ), + ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) 
=> { + Ok(v.clone()) + } + ValType::Constant(f) => region.assign_advice_from_constant( + || format!("load message_{}", i), + self.config.hash_inputs[x], + y, + *f, + ), + e => { + log::error!( + "wrong input type {:?}, must be previously assigned", + e + ); + Err(Error::Synthesis) + } + } + }) + .collect(), + ValTensor::Instance { + dims, + inner: col, + idx, + initial_offset, + .. + } => { + // this should never ever fail + let num_elems = dims[*idx].iter().product::(); + (0..num_elems) + .map(|i| { + let x = i % WIDTH; + let y = i / WIDTH; + region.assign_advice_from_instance( + || "pub input anchor", + *col, + initial_offset + i, + self.config.hash_inputs[x], + y, + ) + }) + .collect() + } + }; + + let offset = message.len() / WIDTH + 1; + + let zero_val = region + .assign_advice_from_constant( + || "", + self.config.hash_inputs[0], + offset, + Fp::ZERO, + ) + .unwrap(); + + Ok((assigned_message?, zero_val)) + }, + ); + log::trace!( + "input (N={:?}) layout took: {:?}", + message.len(), + start_time.elapsed() + ); + res + } + + /// L is the number of inputs to the hash function + /// Takes the cells containing the input values of the hash function and return the cell containing the hash output + /// It uses the pow5_chip to compute the hash + fn layout( + &self, + layouter: &mut impl Layouter, + input: &[ValTensor], + row_offset: usize, + ) -> Result, Error> { + let (mut input_cells, zero_val) = self.layout_inputs(layouter, input)?; + // extract the values from the input cells + let mut assigned_input: Tensor> = + input_cells.iter().map(|e| ValType::from(e.clone())).into(); + let len = assigned_input.len(); + + let start_time = instant::Instant::now(); + + let mut one_iter = false; + // do the Tree dance baby + while input_cells.len() > 1 || !one_iter { + let hashes: Result>, Error> = input_cells + .chunks(L) + .enumerate() + .map(|(i, block)| { + let _start_time = instant::Instant::now(); + + let mut block = block.to_vec(); + let remainder = block.len() % L; + 
+ if remainder != 0 { + block.extend(vec![zero_val.clone(); L - remainder]); + } + + let pow5_chip = Pow5Chip::construct(self.config.pow5_config.clone()); + // initialize the hasher + let hasher = Hash::<_, _, S, ConstantLength, WIDTH, RATE>::init( + pow5_chip, + layouter.namespace(|| "block_hasher"), + )?; + + let hash = hasher.hash( + layouter.namespace(|| "hash"), + block.to_vec().try_into().map_err(|_| Error::Synthesis)?, + ); + + if i == 0 { + log::trace!("block (L={:?}) took: {:?}", L, _start_time.elapsed()); + } + + hash + }) + .collect(); + + log::trace!("hashes (N={:?}) took: {:?}", len, start_time.elapsed()); + one_iter = true; + input_cells = hashes?; + } + + let duration = start_time.elapsed(); + log::trace!("layout (N={:?}) took: {:?}", len, duration); + + let result = Tensor::from(input_cells.iter().map(|e| ValType::from(e.clone()))); + + let output = match result[0].clone() { + ValType::PrevAssigned(v) => v, + _ => { + log::error!("wrong input type, must be previously assigned"); + return Err(Error::Synthesis); + } + }; + + if let Some(instance) = self.config.instance { + layouter.assign_region( + || "constrain output", + |mut region| { + let expected_var = region.assign_advice_from_instance( + || "pub input anchor", + instance, + row_offset, + self.config.hash_inputs[0], + 0, + )?; + + region.constrain_equal(output.cell(), expected_var.cell()) + }, + )?; + + assigned_input.reshape(input[0].dims()).map_err(|e| { + log::error!("reshape failed: {:?}", e); + Error::Synthesis + })?; + + Ok(assigned_input.into()) + } else { + Ok(result.into()) + } + } + + /// + fn run(message: Vec) -> Result>, Box> { + let mut hash_inputs = message; + + let len = hash_inputs.len(); + + let start_time = instant::Instant::now(); + + let mut one_iter = false; + // do the Tree dance baby + while hash_inputs.len() > 1 || !one_iter { + let hashes: Vec = hash_inputs + .par_chunks(L) + .map(|block| { + let mut block = block.to_vec(); + let remainder = block.len() % L; + + if 
remainder != 0 { + block.extend(vec![Fp::ZERO; L - remainder].iter()); + } + + let message = block.try_into().map_err(|_| Error::Synthesis)?; + + Ok(halo2_gadgets::poseidon::primitives::Hash::< + _, + S, + ConstantLength, + { WIDTH }, + { RATE }, + >::init() + .hash(message)) + }) + .collect::, Error>>()?; + one_iter = true; + hash_inputs = hashes; + } + + let duration = start_time.elapsed(); + log::trace!("run (N={:?}) took: {:?}", len, duration); + + Ok(vec![hash_inputs]) + } + + fn num_rows(mut input_len: usize) -> usize { + // this was determined by running the circuit and looking at the number of constraints + // in the test called hash_for_a_range_of_input_sizes, then regressing in python to find the slope + let fixed_cost: usize = 41 * L; + + let mut num_rows = 0; + + loop { + // the number of times the input_len is divisible by L + let num_chunks = input_len / L + 1; + num_rows += num_chunks * fixed_cost; + if num_chunks == 1 { + break; + } + input_len = num_chunks; + } + + num_rows + } +} + +#[allow(unused)] +mod tests { + + use crate::circuit::modules::ModulePlanner; + + use super::{ + spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH}, + *, + }; + + use std::marker::PhantomData; + + use halo2_gadgets::poseidon::primitives::Spec; + use halo2_proofs::{ + circuit::{Layouter, SimpleFloorPlanner, Value}, + plonk::{Circuit, ConstraintSystem}, + }; + use halo2curves::ff::Field; + + const WIDTH: usize = POSEIDON_WIDTH; + const RATE: usize = POSEIDON_RATE; + const R: usize = 240; + + struct HashCircuit, const L: usize> { + message: ValTensor, + _spec: PhantomData, + } + + impl, const L: usize> Circuit for HashCircuit { + type Config = PoseidonConfig; + type FloorPlanner = ModulePlanner; + type Params = (); + + fn without_witnesses(&self) -> Self { + let empty_val: Vec> = vec![Value::::unknown().into()]; + let message: Tensor> = empty_val.into_iter().into(); + + Self { + message: message.into(), + _spec: PhantomData, + } + } + + fn configure(meta: &mut 
ConstraintSystem) -> PoseidonConfig { + PoseidonChip::::configure(meta, ()) + } + + fn synthesize( + &self, + config: PoseidonConfig, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + let chip: PoseidonChip = PoseidonChip::new(config); + chip.layout(&mut layouter, &[self.message.clone()], 0)?; + + Ok(()) + } + } + + #[test] + fn poseidon_hash() { + let rng = rand::rngs::OsRng; + + let message = [Fp::random(rng), Fp::random(rng)]; + let output = PoseidonChip::::run(message.to_vec()).unwrap(); + + let mut message: Tensor> = + message.into_iter().map(|m| Value::known(m).into()).into(); + + let k = 9; + let circuit = HashCircuit:: { + message: message.into(), + _spec: PhantomData, + }; + let prover = halo2_proofs::dev::MockProver::run(k, &circuit, output).unwrap(); + assert_eq!(prover.verify_par(), Ok(())) + } + + #[test] + fn poseidon_hash_longer_input() { + let rng = rand::rngs::OsRng; + + let message = [Fp::random(rng), Fp::random(rng), Fp::random(rng)]; + let output = PoseidonChip::::run(message.to_vec()).unwrap(); + + let mut message: Tensor> = + message.into_iter().map(|m| Value::known(m).into()).into(); + + let k = 9; + let circuit = HashCircuit:: { + message: message.into(), + _spec: PhantomData, + }; + let prover = halo2_proofs::dev::MockProver::run(k, &circuit, output).unwrap(); + assert_eq!(prover.verify_par(), Ok(())) + } + + #[test] + #[ignore] + fn hash_for_a_range_of_input_sizes() { + let rng = rand::rngs::OsRng; + + #[cfg(not(target_arch = "wasm32"))] + env_logger::init(); + + { + let i = 32; + // print a bunch of new lines + println!( + "i is {} -------------------------------------------------", + i + ); + + let message: Vec = (0..i).map(|_| Fp::random(rng)).collect::>(); + let output = + PoseidonChip::::run(message.clone()).unwrap(); + + let mut message: Tensor> = + message.into_iter().map(|m| Value::known(m).into()).into(); + + let k = 17; + let circuit = HashCircuit:: { + message: message.into(), + _spec: PhantomData, + }; + let prover = 
halo2_proofs::dev::MockProver::run(k, &circuit, output).unwrap(); + + assert_eq!(prover.verify_par(), Ok(())) + } + } + + #[test] + #[ignore] + fn poseidon_hash_much_longer_input() { + let rng = rand::rngs::OsRng; + + let mut message: Vec = (0..2048).map(|_| Fp::random(rng)).collect::>(); + + let output = PoseidonChip::::run(message.clone()).unwrap(); + + let mut message: Tensor> = + message.into_iter().map(|m| Value::known(m).into()).into(); + + let k = 17; + let circuit = HashCircuit:: { + message: message.into(), + _spec: PhantomData, + }; + let prover = halo2_proofs::dev::MockProver::run(k, &circuit, output).unwrap(); + assert_eq!(prover.verify_par(), Ok(())) + } +} diff --git a/mnist_ezkl/src/circuit/modules/poseidon/poseidon_params.rs b/mnist_ezkl/src/circuit/modules/poseidon/poseidon_params.rs new file mode 100644 index 0000000..7eb1f4e --- /dev/null +++ b/mnist_ezkl/src/circuit/modules/poseidon/poseidon_params.rs @@ -0,0 +1,984 @@ +//! This file was generated by running generate_params.py +//! Number of round constants: 340 +//! Round constants for GF(p): +//! Parameters for using rate 4 Poseidon with the BN256 field. +//! The parameters can be reproduced by running the following Sage script from +//! [this repository](https://github.com/daira/pasta-hadeshash): +//! +//! ```text +//! $ sage generate_parameters_grain.sage 1 0 254 5 8 60 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001 --rust +//! ``` +//! +//! where 1 means "prime field", 0 means "non-negative sbox", 254 is the bitsize +//! of the field, 5 is the Poseidon width (rate + 1), 8 is the number of full +//! rounds, 60 is the number of partial rounds. +//! 
More info here => +use halo2_proofs::halo2curves::bn256::Fr as Fp; +pub(crate) const ROUND_CONSTANTS: [[Fp; 2]; 64] = [ + [ + Fp::from_raw([ + 0x6c7d_c0db_d0ab_d7a7, + 0xa71a_a177_534c_dd1b, + 0xfe1f_aaba_294c_ba38, + 0x09c4_6e9e_c68e_9bd4, + ]), + Fp::from_raw([ + 0x3c1d_83ff_a604_cb81, + 0xc514_2b3a_e405_b834, + 0x2a97_ed93_7f31_35cf, + 0x0c03_5653_0896_eec4, + ]), + ], + [ + Fp::from_raw([ + 0x317e_a977_cc15_4a30, + 0xa00e_a5aa_bd62_68bd, + 0x142e_5118_2bb5_4cf4, + 0x1e28_a1d9_3569_8ad1, + ]), + Fp::from_raw([ + 0x4cf9_e2b1_2b91_251f, + 0x0e57_57c3_e008_db96, + 0x0809_65db_30e2_98e4, + 0x27af_2d83_1a9d_2748, + ]), + ], + [ + Fp::from_raw([ + 0x79aa_f435_45b7_4e03, + 0x4129_1462_f214_cd08, + 0x3a6a_3cfe_16ae_175a, + 0x1e6f_11ce_60fc_8f51, + ]), + Fp::from_raw([ + 0xf719_2062_68d1_42d3, + 0x0446_2ed1_4c36_13d8, + 0x8541_819c_b681_f0be, + 0x2a67_384d_3bbd_5e43, + ]), + ], + [ + Fp::from_raw([ + 0x3640_8f5d_5c9f_45d0, + 0xb985_e381_f025_1889, + 0x1609_f8e1_2fbf_ecf0, + 0x0b66_fdf3_5609_3a61, + ]), + Fp::from_raw([ + 0xdaa6_852d_bdb0_9e21, + 0x0b26_c83c_c5ce_beed, + 0x830c_6109_3c2a_de37, + 0x012e_e3ec_1e78_d470, + ]), + ], + [ + Fp::from_raw([ + 0x2d10_8e7b_445b_b1b9, + 0x6cd1_c431_b099_b6bb, + 0xfd88_f67f_8175_e3fd, + 0x0252_ba5f_6760_bfbd, + ]), + Fp::from_raw([ + 0xef5a_eaad_7ca9_32f1, + 0x5439_1a89_35ff_71d6, + 0x6c6b_ec3c_ef54_2963, + 0x1794_74cc_eca5_ff67, + ]), + ], + [ + Fp::from_raw([ + 0x7e1a_2589_bbed_2b91, + 0x9c1f_974a_2649_69b3, + 0x9228_ff4a_503f_d4ed, + 0x2c24_2613_79a5_1bfa, + ]), + Fp::from_raw([ + 0x53e6_6c05_5180_1b05, + 0xc2f6_3f50_01fc_0fc5, + 0xac2f_288b_d069_5b43, + 0x1cc1_d7b6_2692_e63e, + ]), + ], + [ + Fp::from_raw([ + 0x5d9e_ff5f_d9c9_1b56, + 0x0078_4dbf_17fb_acd0, + 0xb2ed_55f8_5297_9e96, + 0x2550_5930_1aad_a98b, + ]), + Fp::from_raw([ + 0xb11c_29ce_7e59_efd9, + 0xaea2_4234_970a_8193, + 0x79e1_f5c0_eccd_32b3, + 0x2843_7be3_ac1c_b2e4, + ]), + ], + [ + Fp::from_raw([ + 0x3387_62c3_7f5f_2043, + 0x1854_8da8_fb4f_78d4, + 
0x1ca4_fa6b_5376_6eb1, + 0x2821_6a44_2f2e_1f71, + ]), + Fp::from_raw([ + 0x131f_2377_3234_82c9, + 0xeee1_efce_0309_4581, + 0x1f39_f4e7_056d_d03f, + 0x2c1f_47cd_17fa_5adf, + ]), + ], + [ + Fp::from_raw([ + 0x646b_8566_a621_afc9, + 0xd9da_fca2_7663_8a63, + 0x8632_bcc9_356c_eb7d, + 0x07ab_ad02_b7a5_ebc4, + ]), + Fp::from_raw([ + 0x37da_0c4d_15f9_6c3c, + 0x9429_f908_80a6_9cd1, + 0x275b_33ff_aab5_1dfe, + 0x0230_2646_01ff_df29, + ]), + ], + [ + Fp::from_raw([ + 0x717e_5d66_899a_a0a9, + 0xa864_4145_57ee_289e, + 0xa0f1_6865_6497_ca40, + 0x1bc9_7305_4e51_d905, + ]), + Fp::from_raw([ + 0x2a6b_2228_8f0a_67fc, + 0xd249_aff5_c2d8_421f, + 0x206c_3157_e863_41ed, + 0x2e1c_22f9_6443_5008, + ]), + ], + [ + Fp::from_raw([ + 0xa704_52bc_2bba_86b8, + 0x9e8e_a159_8e46_c9f7, + 0x121c_1d5f_461b_bc50, + 0x1224_f38d_f67c_5378, + ]), + Fp::from_raw([ + 0x69d2_9891_86cd_e20e, + 0xd7bf_e8cd_9dfe_da19, + 0x9280_b4bd_9ed0_068f, + 0x02e4_e69d_8ba5_9e51, + ]), + ], + [ + Fp::from_raw([ + 0x6d47_e973_5d98_018e, + 0x4f19_ee36_4e65_3f07, + 0x7f5d_f81f_c04f_f3ee, + 0x1f1e_ccc3_4aab_a013, + ]), + Fp::from_raw([ + 0xeacb_8a4d_4284_f582, + 0x1424_4480_32cd_1819, + 0x7426_6c30_39a9_a731, + 0x1672_ad3d_709a_3539, + ]), + ], + [ + Fp::from_raw([ + 0x1d2e_d602_df8c_8fc7, + 0xcda6_961f_284d_2499, + 0x56f4_4af5_192b_4ae9, + 0x283e_3fdc_2c6e_420c, + ]), + Fp::from_raw([ + 0x614f_bd69_ff39_4bcc, + 0x6837_51f8_fdff_59d6, + 0xd0db_0957_170f_a013, + 0x1c2a_3d12_0c55_0ecf, + ]), + ], + [ + Fp::from_raw([ + 0x96cb_6b81_7765_3fbd, + 0x143a_9a43_773e_a6f2, + 0xf789_7a73_2345_6efe, + 0x216f_8487_7aac_6172, + ]), + Fp::from_raw([ + 0x11a1_f515_52f9_4788, + 0xceaa_47ea_61ca_59a4, + 0x64ba_7e8e_3e28_d12b, + 0x2c0d_272b_ecf2_a757, + ]), + ], + [ + Fp::from_raw([ + 0xcb4a_6c3d_8954_6f43, + 0x170a_5480_abe0_508f, + 0x484e_e7a7_4c45_4e9f, + 0x16e3_4299_865c_0e28, + ]), + Fp::from_raw([ + 0x48cd_9397_5548_8fc5, + 0x7720_4776_5802_290f, + 0x375a_232a_6fb9_cc71, + 0x175c_eba5_99e9_6f5b, + ]), + ], + [ + Fp::from_raw([ + 
0xd8c5_ffbb_44a1_ee32, + 0x6aa4_10bf_bc35_4f54, + 0xfead_9e17_58b0_2806, + 0x0c75_9444_0dc4_8c16, + ]), + Fp::from_raw([ + 0x9247_9882_d919_fd8d, + 0x760e_2001_3ccf_912c, + 0xc466_db7d_7eb6_fd8f, + 0x1a3c_29bc_39f2_1bb5, + ]), + ], + [ + Fp::from_raw([ + 0x95c8_eeab_cd22_e68f, + 0x0855_d349_074f_5a66, + 0xc098_6ea0_49b2_5340, + 0x0ccf_dd90_6f34_26e5, + ]), + Fp::from_raw([ + 0xe0e6_99b6_7dd9_e796, + 0x66a7_a8a3_fd06_5b3c, + 0x2bdb_475c_e6c9_4118, + 0x14f6_bc81_d9f1_86f6, + ]), + ], + [ + Fp::from_raw([ + 0x88ed_eb73_86b9_7052, + 0xcc09_9810_c9c4_95c8, + 0x9702_ca70_b2f6_c5aa, + 0x0962_b827_89fb_3d12, + ]), + Fp::from_raw([ + 0xafef_0c8f_6a31_a86d, + 0x1328_4ab0_1ef0_2575, + 0xbf20_c79d_e251_27bc, + 0x1a88_0af7_074d_18b3, + ]), + ], + [ + Fp::from_raw([ + 0x4c30_12bb_7ae9_311b, + 0x20af_2924_fc20_ff3f, + 0xcd5e_77f0_211c_154b, + 0x10cb_a184_19a6_a332, + ]), + Fp::from_raw([ + 0x756a_2849_f302_f10d, + 0xfa27_b731_9cae_3406, + 0xbdc7_6ba6_3a9e_aca8, + 0x057e_62a9_a8f8_9b3e, + ]), + ], + [ + Fp::from_raw([ + 0xafa0_413b_4428_0cee, + 0xb961_303b_bf65_cff5, + 0xd44a_df53_84b4_988c, + 0x287c_971d_e91d_c0ab, + ]), + Fp::from_raw([ + 0x6f7f_7960_e306_891d, + 0x1e56_2bc4_6d4a_ba4e, + 0xb3bc_a9da_0cca_908f, + 0x21df_3388_af16_87bb, + ]), + ], + [ + Fp::from_raw([ + 0x3eff_8b56_0e16_82b3, + 0x789d_f8f7_0b49_8fd8, + 0x3e25_cc97_4d09_34cd, + 0x1be5_c887_d25b_ce70, + ]), + Fp::from_raw([ + 0x48d5_9c27_06a0_d5c1, + 0xd2cb_5d42_fda5_acea, + 0x6811_7175_cea2_cd0d, + 0x268d_a36f_76e5_68fb, + ]), + ], + [ + Fp::from_raw([ + 0xbd06_460c_c26a_5ed6, + 0xc5d8_bb74_135e_bd05, + 0xc609_beaf_5510_ecec, + 0x0e17_ab09_1f6e_ae50, + ]), + Fp::from_raw([ + 0x040f_5caa_1f62_af40, + 0x91ef_62d8_cf83_d270, + 0x7aee_535a_b074_a430, + 0x04d7_27e7_28ff_a0a6, + ]), + ], + [ + Fp::from_raw([ + 0x2b15_417d_7e39_ca6e, + 0x3370_2ac1_0f1b_fd86, + 0x81b5_4976_2bc0_22ed, + 0x0ddb_d7bf_9c29_3415, + ]), + Fp::from_raw([ + 0x8a29_c49c_8789_654b, + 0x34f5_b0d1_d3af_9b58, + 0x7681_62e8_2989_c6c2, + 
0x2790_eb33_5162_1752, + ]), + ], + [ + Fp::from_raw([ + 0x84b7_6420_6142_f9e9, + 0x395f_3d9a_b8b2_fd09, + 0x4471_9501_93d8_a570, + 0x1e45_7c60_1a63_b73e, + ]), + Fp::from_raw([ + 0xc4c6_86fc_46e0_91b0, + 0xfa90_ecd0_c43f_f91f, + 0x638d_6ab2_bbe7_135f, + 0x21ae_6430_1dca_9625, + ]), + ], + [ + Fp::from_raw([ + 0x5858_534e_ed8d_350b, + 0x854b_e9e3_432e_0955, + 0x4da2_9316_6f49_4928, + 0x0379_f63c_8ce3_468d, + ]), + Fp::from_raw([ + 0x8c9f_58a3_24c3_5049, + 0xca0e_4921_a466_86ac, + 0x6a74_4a08_0809_e054, + 0x002d_5642_0359_d026, + ]), + ], + [ + Fp::from_raw([ + 0x0fc2_c5af_9635_15a6, + 0xda8d_6245_9e21_f409, + 0x1d68_b3cd_32e1_0bbe, + 0x1231_58e5_965b_5d9b, + ]), + Fp::from_raw([ + 0x60c8_0eb4_9cad_9ec1, + 0x0fbb_2b6f_5283_6d4e, + 0x661d_14bb_f6cb_e042, + 0x0be2_9fc4_0847_a941, + ]), + ], + [ + Fp::from_raw([ + 0x2338_02f2_4fdf_4c1a, + 0x36db_9d85_9cad_5f9a, + 0x5771_6142_015a_453c, + 0x1ac9_6991_dec2_bb05, + ]), + Fp::from_raw([ + 0x51ca_3355_bcb0_627e, + 0x5e12_c9fa_97f1_8a92, + 0x5f49_64fc_61d2_3b3e, + 0x1596_443f_763d_bcc2, + ]), + ], + [ + Fp::from_raw([ + 0xd6d0_49ea_e3ba_3212, + 0xf185_7d9f_17e7_15ae, + 0x6b28_61d4_ec3a_eae0, + 0x12e0_bcd3_654b_dfa7, + ]), + Fp::from_raw([ + 0x04e6_c76c_7cf9_64ba, + 0xceab_ac7f_3715_4b19, + 0x9ea7_3d4a_f9af_2a50, + 0x0fc9_2b4f_1bbe_a82b, + ]), + ], + [ + Fp::from_raw([ + 0x9c7e_9652_3387_2762, + 0xb14f_7c77_2223_6f4f, + 0xd6f2_e592_a801_3f40, + 0x1f9c_0b16_1044_6442, + ]), + Fp::from_raw([ + 0x8d15_9f64_3dbb_f4d3, + 0x050d_914d_a38b_4c05, + 0xf8cd_e061_57a7_82f4, + 0x0ebd_7424_4ae7_2675, + ]), + ], + [ + Fp::from_raw([ + 0x7a83_9839_dccf_c6d1, + 0x3b06_71e9_7346_ee39, + 0x69a9_fafd_4ab9_51c0, + 0x2cb7_f0ed_39e1_6e9f, + ]), + Fp::from_raw([ + 0x90c7_2bca_7352_d9bf, + 0xce76_1d05_14ce_5266, + 0x5605_443e_e41b_ab20, + 0x1a9d_6e2e_cff0_22cc, + ]), + ], + [ + Fp::from_raw([ + 0x87da_182d_648e_c72f, + 0xd0c1_3326_a9a7_ba30, + 0x5ea8_3c3b_c44a_9331, + 0x2a11_5439_607f_335a, + ]), + Fp::from_raw([ + 0x9535_c115_c5a4_c060, + 
0xe738_b563_05cd_44f2, + 0x15b8_fa7a_ee3e_3410, + 0x23f9_b652_9b5d_040d, + ]), + ], + [ + Fp::from_raw([ + 0x260e_b939_f0e6_e8a7, + 0xa3ce_97c1_6d58_b68b, + 0x249a_c6ba_484b_b9c3, + 0x0587_2c16_db0f_72a2, + ]), + Fp::from_raw([ + 0x2b62_4a7c_dedd_f6a7, + 0x0219_b615_1d55_b5c5, + 0xca20_fb80_1180_75f4, + 0x1300_bdee_08bb_7824, + ]), + ], + [ + Fp::from_raw([ + 0x072e_4e7b_7d52_b376, + 0x8d7a_d299_16d9_8cb1, + 0xe638_1786_3a8f_6c28, + 0x19b9_b63d_2f10_8e17, + ]), + Fp::from_raw([ + 0x24a2_0128_481b_4f7f, + 0x13d1_c887_26b5_ec42, + 0xb5bd_a237_6685_22f6, + 0x015b_ee13_57e3_c015, + ]), + ], + [ + Fp::from_raw([ + 0xea92_c785_b128_ffd1, + 0xfe1e_1ce4_bab2_18cb, + 0x1b97_07a4_f161_5e4e, + 0x2953_736e_94bb_6b9f, + ]), + Fp::from_raw([ + 0x4ce7_266e_d660_8dfc, + 0x851b_98d3_72b4_5f54, + 0x862f_8061_80c0_385f, + 0x0b06_9353_ba09_1618, + ]), + ], + [ + Fp::from_raw([ + 0x4f58_8ac9_7d81_f429, + 0x55ae_b7eb_9306_b64e, + 0x15e4_e0bc_fb93_817e, + 0x304f_74d4_61cc_c131, + ]), + Fp::from_raw([ + 0xb8ee_5415_cde9_13fc, + 0xaad2_a164_a461_7a4c, + 0xe8a3_3f5e_77df_e4f5, + 0x15bb_f146_ce9b_ca09, + ]), + ], + [ + Fp::from_raw([ + 0xa9ff_2385_9572_c8c6, + 0x9b8f_4b85_0405_c10c, + 0x4490_1031_4879_64ed, + 0x0ab4_dfe0_c274_2cde, + ]), + Fp::from_raw([ + 0x251d_e39f_9639_779a, + 0xef5e_edfe_a546_dea9, + 0x97f4_5f76_49a1_9675, + 0x0e32_db32_0a04_4e31, + ]), + ], + [ + Fp::from_raw([ + 0xa307_8efa_516d_a016, + 0x6797_733a_8277_4896, + 0xb276_35a7_8b68_88e6, + 0x0a17_56aa_1f37_8ca4, + ]), + Fp::from_raw([ + 0x4254_d6a2_a25d_93ef, + 0x95e6_1d32_8f85_efa9, + 0x47fd_1717_7f95_2ef8, + 0x044c_4a33_b10f_6934, + ]), + ], + [ + Fp::from_raw([ + 0xd37b_07b5_466c_4b8b, + 0xfe08_79d7_9a49_6891, + 0xbe65_5b53_7f66_f700, + 0x2ed3_611b_725b_8a70, + ]), + Fp::from_raw([ + 0xd833_9ea7_1208_58aa, + 0xadfd_eb9c_fdd3_47b5, + 0xc8ec_c3d7_22aa_2e0e, + 0x1f9b_a4e8_bab7_ce42, + ]), + ], + [ + Fp::from_raw([ + 0xb740_56f8_65c5_d3da, + 0xa38e_82ac_4502_066d, + 0x8f7e_e907_a84e_518a, + 0x1b23_3043_052e_8c28, + ]), + 
Fp::from_raw([ + 0xca2f_97b0_2087_5954, + 0x9020_53bf_c0f1_4db0, + 0x7403_1ab7_2bd5_5b4c, + 0x2431_e1cc_164b_b8d0, + ]), + ], + [ + Fp::from_raw([ + 0xa791_f273_9658_01fd, + 0xa13e_3220_9758_3319, + 0x30cd_6953_a0a7_db45, + 0x082f_934c_91f5_aac3, + ]), + Fp::from_raw([ + 0x9ad6_bb93_0c48_997c, + 0xc772_45e2_ae7c_be99, + 0xa34b_e074_3155_42a3, + 0x2b9a_0a22_3e75_38b0, + ]), + ], + [ + Fp::from_raw([ + 0xb0b5_89cc_7021_4e7d, + 0x8164_163e_75a8_a00e, + 0xceb8_5483_b887_a9be, + 0x0e1c_d91e_dd2c_fa2c, + ]), + Fp::from_raw([ + 0x88d3_2460_1ceb_e2f9, + 0x9977_4f19_854d_00f5, + 0xc951_f614_77e3_6989, + 0x2e1e_ac0f_2bfd_fd63, + ]), + ], + [ + Fp::from_raw([ + 0x23d7_4811_5b50_0b83, + 0x7345_784d_8efd_b33c, + 0x0c76_158e_769d_6d15, + 0x0cbf_a95f_37fb_7406, + ]), + Fp::from_raw([ + 0x980c_232d_fa4a_4f84, + 0x76d9_91e3_a775_13d9, + 0xd65a_d49d_8a61_e9a6, + 0x08f0_5b3b_e923_ed44, + ]), + ], + [ + Fp::from_raw([ + 0x25a2_dd51_0c04_7ef6, + 0xe728_4925_dc07_58a3, + 0x52bf_8e21_984d_0443, + 0x2271_9e2a_070b_cd08, + ]), + Fp::from_raw([ + 0xf41f_62b2_f268_30c0, + 0x7bdb_f036_1199_82c0, + 0xc060_f7fc_c3a1_ab4c, + 0x041f_596a_9ee1_cb2b, + ]), + ], + [ + Fp::from_raw([ + 0x19fc_dd09_86b1_0f89, + 0x021b_e1c2_d0dc_464a, + 0x8762_8eb0_6f6b_1d4c, + 0x233f_d35d_e1be_520a, + ]), + Fp::from_raw([ + 0xefcb_453c_61c9_c267, + 0xd31e_078a_a1b4_707e, + 0x4325_e0a4_23eb_c810, + 0x0524_b46d_1aa8_7a5e, + ]), + ], + [ + Fp::from_raw([ + 0xcc44_8623_7c51_5211, + 0x4227_bb95_4b0f_3199, + 0xce47_fcac_894b_8582, + 0x2c34_f424_c81e_5716, + ]), + Fp::from_raw([ + 0xf330_1032_7de4_915e, + 0x2dd2_025b_5457_cc97, + 0x207e_ffc2_b554_1fb7, + 0x0b5f_2a4b_6338_7819, + ]), + ], + [ + Fp::from_raw([ + 0xaefa_c41f_e05c_659f, + 0xc174_35d2_f57a_f6ce, + 0xc5b7_2fe4_39d2_cfd6, + 0x2220_7856_082c_cc54, + ]), + Fp::from_raw([ + 0x2785_4048_ce2c_8171, + 0xcdfb_2101_94ca_f79f, + 0x4e24_159b_7f89_50b5, + 0x24d5_7a8b_f5da_63fe, + ]), + ], + [ + Fp::from_raw([ + 0x7391_9bb2_3b79_396e, + 0x374a_d709_7bb0_1a85, + 
0x3b37_1d75_bd69_3f98, + 0x0afa_b181_fdd5_e058, + ]), + Fp::from_raw([ + 0xf162_90d6_2b11_28ee, + 0x76c0_0571_94c1_6c0b, + 0x998a_52ef_ac7c_bd56, + 0x2dba_9b10_8f20_8772, + ]), + ], + [ + Fp::from_raw([ + 0x5aff_13e6_bce4_20b3, + 0xcbb8_3de0_bd59_2b25, + 0x56f8_81c7_88f5_3f83, + 0x2634_9b66_edb8_b16f, + ]), + Fp::from_raw([ + 0x2352_88a3_e6f1_37db, + 0xd81a_56d2_8ecc_193b, + 0x685e_95f9_2339_753a, + 0x25af_7ce0_e5e1_0357, + ]), + ], + [ + Fp::from_raw([ + 0x1f7c_0187_fe35_011f, + 0x70ee_d7aa_e88b_2bff, + 0xc094_d6a5_5edd_68b9, + 0x25b4_ce7b_d229_4390, + ]), + Fp::from_raw([ + 0x8cb9_d54c_1e02_b631, + 0xde9c_ef28_ebdf_30b1, + 0x387e_53f1_908a_88e5, + 0x22c5_43f1_0f6c_89ec, + ]), + ], + [ + Fp::from_raw([ + 0xdf66_8e74_882f_87a9, + 0x425e_906a_919d_7a34, + 0x4fc7_908a_9f19_1e1e, + 0x0236_f93e_7789_c472, + ]), + Fp::from_raw([ + 0x9cb4_97af_980c_4b52, + 0x652b_dae1_14eb_0165, + 0x0e7d_27e3_7d05_da99, + 0x2935_0b40_1166_ca01, + ]), + ], + [ + Fp::from_raw([ + 0xee12_6091_6652_363f, + 0x65ed_b75d_844e_bb89, + 0x6bd3_1bba_b547_f75a, + 0x0eed_787d_6582_0d3f, + ]), + Fp::from_raw([ + 0x1906_f656_f4de_6fad, + 0xfdcd_0e99_bd94_297d, + 0x036a_753f_520b_3291, + 0x07cc_1170_f13b_46f2, + ]), + ], + [ + Fp::from_raw([ + 0x2059_4356_89e8_acea, + 0x9087_86d7_f9f5_d10c, + 0xf49b_cf61_3a3d_30b1, + 0x22b9_3923_3b1d_7205, + ]), + Fp::from_raw([ + 0xadd6_50ac_e60a_e5a6, + 0x740f_083a_5aa8_5438, + 0x8aad_1dc8_bc33_e870, + 0x0145_1762_a0aa_b81c, + ]), + ], + [ + Fp::from_raw([ + 0xe704_fec0_892f_ce89, + 0xe32e_aa61_dec7_da57, + 0x61fa_bf10_25d4_6d1f, + 0x2350_6bb5_d872_7d44, + ]), + Fp::from_raw([ + 0x7f8b_d689_0735_5522, + 0x2a37_0953_1e1e_fea9, + 0xbac0_6ae3_f71b_dd09, + 0x2e48_4c44_e838_aea0, + ]), + ], + [ + Fp::from_raw([ + 0x4541_8da2_6835_b54c, + 0xaf4a_5945_45ce_dc25, + 0x379e_78c5_0bd2_e42b, + 0x0f4b_c7d0_7eba_fd64, + ]), + Fp::from_raw([ + 0xe620_996d_50d8_e74e, + 0x5158_2388_725d_f460, + 0xfa76_6378_62fa_aee8, + 0x1f4d_3c8f_6583_e9e5, + ]), + ], + [ + Fp::from_raw([ + 
0x53eb_9bcb_48fe_7389, + 0xfae0_2abc_7b68_1d91, + 0x2660_d07b_e0e4_a988, + 0x0935_14e0_c707_11f8, + ]), + Fp::from_raw([ + 0x4a58_e0a3_47e1_53d8, + 0x43ee_83ec_e472_28f2, + 0x4669_9a2b_5f3b_c036, + 0x1ada_b0c8_e2b3_bad3, + ]), + ], + [ + Fp::from_raw([ + 0x1a22_dbef_9e80_dad2, + 0x378c_1b94_b807_2bac, + 0xd147_09eb_b474_641a, + 0x1672_b172_6057_d99d, + ]), + Fp::from_raw([ + 0x30d4_7b23_9b47_9c14, + 0xc5d8_e2fa_e0ac_c4ee, + 0x8f44_f53f_dcab_468c, + 0x1dfd_53d4_576a_f2e3, + ]), + ], + [ + Fp::from_raw([ + 0xbc7f_2077_5320_5c60, + 0xe6d7_7d64_0f6f_c3de, + 0xa70a_3626_3a37_e17f, + 0x0c68_88a1_0b75_b0f3, + ]), + Fp::from_raw([ + 0x8509_1ecc_a9d1_e508, + 0x611a_61e0_0ee6_848b, + 0x92b3_4a7e_77d1_2fe8, + 0x1add_b933_a65b_e770, + ]), + ], + [ + Fp::from_raw([ + 0x7935_628e_299d_1791, + 0xf638_ff54_25f0_afff, + 0x5c10_ae18_d1de_933c, + 0x00d7_540d_cd26_8a84, + ]), + Fp::from_raw([ + 0xd316_939d_20b8_2c0e, + 0x26fe_dde4_acd9_9db1, + 0x01b2_827a_5664_ca9c, + 0x140c_0e42_687e_9ead, + ]), + ], + [ + Fp::from_raw([ + 0xc091_e2ae_5656_5984, + 0xc20a_0f9b_24f8_c5ed, + 0x91ba_89b8_d13d_1806, + 0x2f0c_3a11_5d43_17d1, + ]), + Fp::from_raw([ + 0xd8c5_38a1_dc95_8c61, + 0x08a0_cff6_70b2_2b82, + 0x3006_ed22_0cf9_c810, + 0x0c4e_e778_ff7c_1455, + ]), + ], + [ + Fp::from_raw([ + 0x27c3_d748_5de7_4c69, + 0x9424_ed26_c0ac_c662, + 0x3693_f004_40cc_c360, + 0x1704_f276_6d46_f82c, + ]), + Fp::from_raw([ + 0x39b6_6fe9_009c_3cfa, + 0xf076_9c9f_8544_e402, + 0xa7a0_2c1b_51d2_44ab, + 0x2f2d_19cc_3ea5_d78e, + ]), + ], + [ + Fp::from_raw([ + 0xd6c7_66a8_06fc_6629, + 0xdd7e_e6cb_9cfe_d9c7, + 0x5053_f112_e2a8_e8dc, + 0x1ae0_3853_b75f_caba, + ]), + Fp::from_raw([ + 0x4e41_a86d_daf0_56d5, + 0x3556_921b_2d6f_014e, + 0x51d1_31d0_fa61_aa5f, + 0x0971_aabf_7952_41df, + ]), + ], + [ + Fp::from_raw([ + 0x5f5c_29f7_bfe2_f646, + 0xda62_4f83_80df_1c87, + 0x91d4_cf6b_6e0d_e73e, + 0x1408_c316_e601_4e1a, + ]), + Fp::from_raw([ + 0x4169_1f39_822e_f5bd, + 0x6c89_f1f7_73ef_2853, + 0x248a_be42_b543_093b, + 
0x1667_f3fe_2edb_e850, + ]), + ], + [ + Fp::from_raw([ + 0x424c_6957_6500_fe37, + 0x5b81_7184_09e5_c133, + 0xa48b_0a03_557c_df91, + 0x13bf_7c5d_0d2c_4376, + ]), + Fp::from_raw([ + 0x19bc_0ba7_43a6_2c2c, + 0x024b_9534_7856_b797, + 0x3016_adf3_d353_3c24, + 0x0762_0a6d_fb0b_6cec, + ]), + ], + [ + Fp::from_raw([ + 0x1675_de3e_1982_b4d0, + 0x75d2_959e_2f32_2b73, + 0x36a8_ca08_bdbd_d8b0, + 0x1574_c7ef_0c43_545f, + ]), + Fp::from_raw([ + 0xc06e_03a7_ff83_78f0, + 0x5bd4_1845_71c2_54fd, + 0xfd56_7970_a717_ceec, + 0x269e_4b5b_7a2e_b21a, + ]), + ], +]; // n: 254 + // t: 5 + // N: 1270 + // Result Algorithm 1: + // [True, 0] + // Result Algorithm 2: + // [True, None] + // Result Algorithm 3: + // [True, None] + // Prime number: 0x0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001 + // MDS matrix: +pub(crate) const MDS: [[Fp; 2]; 2] = [ + [ + Fp::from_raw([ + 0xbcec_a70b_d2af_7ad5, + 0xaf07_f38a_f8c9_52a7, + 0xec10_3453_51a2_3a3a, + 0x066f_6f85_d6f6_8a85, + ]), + Fp::from_raw([ + 0x0546_2b9f_8125_b1e8, + 0x20a7_c02b_bd8b_ea73, + 0x7782_e150_9b1d_0fdb, + 0x2b9d_4b41_10c9_ae99, + ]), + ], + [ + Fp::from_raw([ + 0xf573_f431_221f_8ff9, + 0xb6c0_9d55_7013_fff1, + 0x2bf6_7a44_93cc_262f, + 0x0cc5_7cdb_b085_07d6, + ]), + Fp::from_raw([ + 0x21bc_d147_9432_03c8, + 0xade8_57e8_6eb5_c3a1, + 0xa31a_6ed6_9724_e1ad, + 0x1274_e649_a32e_d355, + ]), + ], +]; // Inverse MDS matrix: +pub(crate) const MDS_INV: [[Fp; 2]; 2] = [ + [ + Fp::from_raw([ + 0x8dbe_bd0f_a8c5_3e66, + 0x0554_569d_9b29_d1ea, + 0x7081_9ab1_c784_6f21, + 0x13ab_ec39_0ada_7f43, + ]), + Fp::from_raw([ + 0xaaf6_185b_1a1e_60fe, + 0xbd52_1ead_5dfe_0345, + 0x4c98_62a1_d97d_1510, + 0x1eb9_e1dc_19a3_3a62, + ]), + ], + [ + Fp::from_raw([ + 0x763f_7875_036b_cb02, + 0x8ce5_1690_30a2_ad69, + 0x601a_bc49_fdad_4f03, + 0x0fc1_c939_4db8_9bb2, + ]), + Fp::from_raw([ + 0x8abc_ed6b_d147_c8be, + 0x2b7e_ac34_3459_61bc, + 0x9502_054e_dc03_e7b2, + 0x16a9_e98c_493a_902b, + ]), + ], +]; diff --git 
a/mnist_ezkl/src/circuit/modules/poseidon/spec.rs b/mnist_ezkl/src/circuit/modules/poseidon/spec.rs new file mode 100644 index 0000000..2588e25 --- /dev/null +++ b/mnist_ezkl/src/circuit/modules/poseidon/spec.rs @@ -0,0 +1,49 @@ +//! This file was generated by running generate_params.py +//! Specification for rate 5 Poseidon using the BN256 curve. +//! Patterned after [halo2_gadgets::poseidon::primitives::P128Pow5T3] +use halo2_gadgets::poseidon::primitives::*; +use halo2_proofs::arithmetic::Field; +use halo2_proofs::halo2curves::bn256::Fr as Fp; + +use super::poseidon_params; + +/// The specification for the Poseidon hash function. +#[derive(Debug, Clone, Copy)] +pub struct PoseidonSpec; + +/// Basically the number of columns allocated within the poseidon chip +pub const POSEIDON_WIDTH: usize = 2; +/// The number of full SBox rounds +pub const POSEIDON_RATE: usize = 1; + +pub(crate) type Mds = [[Fp; T]; T]; + +impl Spec for PoseidonSpec { + fn full_rounds() -> usize { + 8 + } + + fn partial_rounds() -> usize { + 56 + } + + fn sbox(val: Fp) -> Fp { + val.pow_vartime([5]) + } + + fn secure_mds() -> usize { + unimplemented!() + } + + fn constants() -> ( + Vec<[Fp; POSEIDON_WIDTH]>, + Mds, + Mds, + ) { + ( + poseidon_params::ROUND_CONSTANTS[..].to_vec(), + poseidon_params::MDS, + poseidon_params::MDS_INV, + ) + } +} diff --git a/mnist_ezkl/src/circuit/ops/base.rs b/mnist_ezkl/src/circuit/ops/base.rs new file mode 100644 index 0000000..8c8d527 --- /dev/null +++ b/mnist_ezkl/src/circuit/ops/base.rs @@ -0,0 +1,158 @@ +use crate::tensor::TensorType; +use std::{ + fmt, + ops::{Add, Mul, Neg, Sub}, +}; + +#[allow(missing_docs)] +/// An enum representing the operations that can be used to express more complex operations via accumulation +#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub enum BaseOp { + Dot, + DotInit, + CumProdInit, + CumProd, + Identity, + Add, + Mult, + Sub, + SumInit, + Sum, + Neg, + Range { tol: i32 }, + IsZero, + IsBoolean, +} + +/// 
/// Matches a [BaseOp] to an operation over inputs
impl BaseOp {
    /// Forward evaluation of a non-accumulating op: combines a single pair of
    /// operands `(a, b)` according to `self`.
    ///
    /// Unary / pass-through variants (Identity, Neg, Range, IsZero, IsBoolean)
    /// act on the second operand only.
    ///
    /// # Panics
    /// Panics if called on an accumulating op (Dot, Sum, CumProd and their
    /// *Init variants).
    pub fn nonaccum_f<
        T: TensorType + Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Neg<Output = T>,
    >(
        &self,
        inputs: (T, T),
    ) -> T {
        let (a, b) = inputs;
        match &self {
            BaseOp::Add => a + b,
            BaseOp::Identity => b,
            BaseOp::Neg => -b,
            BaseOp::Sub => a - b,
            BaseOp::Mult => a * b,
            // Range / IsZero / IsBoolean are constraint-only checks: the
            // forward value is simply the operand being checked.
            BaseOp::Range { .. } => b,
            BaseOp::IsZero => b,
            BaseOp::IsBoolean => b,
            _ => panic!("nonaccum_f called on accumulating operation"),
        }
    }

    /// Forward evaluation of an accumulating op: folds the operand vectors
    /// `a` / `b` into the running accumulator.
    ///
    /// `*Init` variants start a fresh accumulator; the non-`Init` variants
    /// extend `prev_output` (the accumulator from the previous row).
    ///
    /// # Panics
    /// Panics if called on a non-accumulating operation, or (via `unwrap`) if
    /// `T` has no zero/one element.
    pub fn accum_f<
        T: TensorType + Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Neg<Output = T>,
    >(
        &self,
        prev_output: T,
        a: Vec<T>,
        b: Vec<T>,
    ) -> T {
        let zero = T::zero().unwrap();
        let one = T::one().unwrap();

        match &self {
            BaseOp::DotInit => a.into_iter().zip(b).fold(zero, |acc, (a, b)| acc + a * b),
            BaseOp::Dot => prev_output + a.into_iter().zip(b).fold(zero, |acc, (a, b)| acc + a * b),
            BaseOp::CumProdInit => b.into_iter().fold(one, |acc, b| acc * b),
            BaseOp::CumProd => prev_output * b.into_iter().fold(one, |acc, b| acc * b),
            BaseOp::SumInit => b.into_iter().fold(zero, |acc, b| acc + b),
            BaseOp::Sum => prev_output + b.into_iter().fold(zero, |acc, b| acc + b),
            _ => panic!("accum_f called on non-accumulating operation"),
        }
    }

    /// Short uppercase name used for gate labels and `Display`.
    pub fn as_str(&self) -> &'static str {
        match self {
            BaseOp::Identity => "IDENTITY",
            BaseOp::Dot => "DOT",
            BaseOp::DotInit => "DOTINIT",
            BaseOp::CumProdInit => "CUMPRODINIT",
            BaseOp::CumProd => "CUMPROD",
            BaseOp::Add => "ADD",
            BaseOp::Neg => "NEG",
            BaseOp::Sub => "SUB",
            BaseOp::Mult => "MULT",
            BaseOp::Sum => "SUM",
            BaseOp::SumInit => "SUMINIT",
            BaseOp::Range { .. } => "RANGE",
            BaseOp::IsZero => "ISZERO",
            BaseOp::IsBoolean => "ISBOOLEAN",
        }
    }

    /// Returns the range of the query offset for this operation, as
    /// (starting rotation offset, number of rows queried). Accumulating
    /// non-`Init` ops query the previous row too, hence (-1, 2).
    pub fn query_offset_rng(&self) -> (i32, usize) {
        match self {
            BaseOp::Identity => (0, 1),
            BaseOp::Neg => (0, 1),
            BaseOp::DotInit => (0, 1),
            BaseOp::Dot => (-1, 2),
            BaseOp::CumProd => (-1, 2),
            BaseOp::CumProdInit => (0, 1),
            BaseOp::Add => (0, 1),
            BaseOp::Sub => (0, 1),
            BaseOp::Mult => (0, 1),
            BaseOp::Sum => (-1, 2),
            BaseOp::SumInit => (0, 1),
            BaseOp::Range { .. } => (0, 1),
            BaseOp::IsZero => (0, 1),
            BaseOp::IsBoolean => (0, 1),
        }
    }

    /// Returns the number of inputs for this operation.
    pub fn num_inputs(&self) -> usize {
        match self {
            BaseOp::Identity => 1,
            BaseOp::Neg => 1,
            BaseOp::DotInit => 2,
            BaseOp::Dot => 2,
            BaseOp::CumProdInit => 1,
            BaseOp::CumProd => 1,
            BaseOp::Add => 2,
            BaseOp::Sub => 2,
            BaseOp::Mult => 2,
            BaseOp::Sum => 1,
            BaseOp::SumInit => 1,
            BaseOp::Range { .. } => 1,
            BaseOp::IsZero => 1,
            BaseOp::IsBoolean => 1,
        }
    }

    /// Index (within the rows returned by [`Self::query_offset_rng`]) of the
    /// output cell that the gate constrains — 1 for accumulating ops that
    /// reference the previous row, 0 otherwise.
    pub fn constraint_idx(&self) -> usize {
        match self {
            BaseOp::Identity => 0,
            BaseOp::Neg => 0,
            BaseOp::DotInit => 0,
            BaseOp::Dot => 1,
            BaseOp::Add => 0,
            BaseOp::Sub => 0,
            BaseOp::Mult => 0,
            BaseOp::Range { .. } => 0,
            BaseOp::Sum => 1,
            BaseOp::SumInit => 0,
            BaseOp::CumProd => 1,
            BaseOp::CumProdInit => 0,
            BaseOp::IsZero => 0,
            BaseOp::IsBoolean => 0,
        }
    }
}

impl fmt::Display for BaseOp {
    // Delegates to `as_str` so log/gate labels and Display agree.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "{}", self.as_str())
    }
}

use std::str::FromStr;

use thiserror::Error;

use halo2_proofs::{
    circuit::Layouter,
    plonk::{ConstraintSystem, Constraints, Expression, Selector},
    poly::Rotation,
};
use log::debug;
#[cfg(feature = "python-bindings")]
use pyo3::{
    conversion::{FromPyObject, PyTryFrom},
    exceptions::PyValueError,
    prelude::*,
    types::PyString,
};
use serde::{Deserialize, Serialize};

use crate::{
    circuit::ops::base::BaseOp,
    circuit::{table::Table, utils},
    tensor::{Tensor, TensorType, ValTensor, VarTensor},
};
use std::{collections::BTreeMap, error::Error, marker::PhantomData};

use super::{lookup::LookupOp, region::RegionCtx, Op};
use halo2curves::ff::{Field, PrimeField};

/// circuit related errors.
+#[derive(Debug, Error)] +pub enum CircuitError { + /// Shape mismatch in circuit construction + #[error("dimension mismatch in circuit construction for op: {0}")] + DimMismatch(String), + /// Error when instantiating lookup tables + #[error("failed to instantiate lookup tables")] + LookupInstantiation, + /// A lookup table was was already assigned + #[error("attempting to initialize an already instantiated lookup table")] + TableAlreadyAssigned, + /// This operation is unsupported + #[error("unsupported operation in graph")] + UnsupportedOp, + /// + #[error("invalid einsum expression")] + InvalidEinsum, +} + +#[allow(missing_docs)] +/// An enum representing activating the sanity checks we can perform on the accumulated arguments +#[derive( + Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize, Default, Copy, +)] +pub enum CheckMode { + #[default] + SAFE, + UNSAFE, +} + +impl From for CheckMode { + fn from(value: String) -> Self { + match value.to_lowercase().as_str() { + "safe" => CheckMode::SAFE, + "unsafe" => CheckMode::UNSAFE, + _ => { + log::error!("Invalid value for CheckMode"); + log::warn!("defaulting to SAFE"); + CheckMode::SAFE + } + } + } +} + +#[allow(missing_docs)] +/// An enum representing the tolerance we can accept for the accumulated arguments, either absolute or percentage +#[derive(Clone, Default, Debug, PartialEq, PartialOrd, Serialize, Deserialize, Copy)] +pub struct Tolerance { + pub val: f32, + pub scale: utils::F32, +} + +impl FromStr for Tolerance { + type Err = String; + + fn from_str(s: &str) -> Result { + if let Ok(val) = s.parse::() { + Ok(Tolerance { + val, + scale: utils::F32(1.0), + }) + } else { + Err( + "Invalid tolerance value provided. It should expressed as a percentage (f32)." 
+ .to_string(), + ) + } + } +} + +impl From for Tolerance { + fn from(value: f32) -> Self { + Tolerance { + val: value, + scale: utils::F32(1.0), + } + } +} + +#[cfg(feature = "python-bindings")] +/// Converts CheckMode into a PyObject (Required for CheckMode to be compatible with Python) +impl IntoPy for CheckMode { + fn into_py(self, py: Python) -> PyObject { + match self { + CheckMode::SAFE => "safe".to_object(py), + CheckMode::UNSAFE => "unsafe".to_object(py), + } + } +} + +#[cfg(feature = "python-bindings")] +/// Obtains CheckMode from PyObject (Required for CheckMode to be compatible with Python) +impl<'source> FromPyObject<'source> for CheckMode { + fn extract(ob: &'source PyAny) -> PyResult { + let trystr = ::try_from(ob)?; + let strval = trystr.to_string(); + match strval.to_lowercase().as_str() { + "safe" => Ok(CheckMode::SAFE), + "unsafe" => Ok(CheckMode::UNSAFE), + _ => Err(PyValueError::new_err("Invalid value for CheckMode")), + } + } +} + +#[cfg(feature = "python-bindings")] +/// Converts Tolerance into a PyObject (Required for Tolerance to be compatible with Python) +impl IntoPy for Tolerance { + fn into_py(self, py: Python) -> PyObject { + (self.val, self.scale.0).to_object(py) + } +} + +#[cfg(feature = "python-bindings")] +/// Obtains Tolerance from PyObject (Required for Tolerance to be compatible with Python) +impl<'source> FromPyObject<'source> for Tolerance { + fn extract(ob: &'source PyAny) -> PyResult { + if let Ok((val, scale)) = ob.extract::<(f32, f32)>() { + Ok(Tolerance { + val, + scale: utils::F32(scale), + }) + } else { + Err(PyValueError::new_err("Invalid tolerance value provided. ")) + } + } +} + +/// Configuration for an accumulated arg. +#[derive(Clone, Debug, Default)] +pub struct BaseConfig { + /// the inputs to the accumulated operations. 
+ pub inputs: Vec, + /// the VarTensor reserved for lookup operations (could be an element of inputs) + /// Note that you should be careful to ensure that the lookup_input is not simultaneously assigned to by other non-lookup operations eg. in the case of composite ops. + pub lookup_input: VarTensor, + /// the (currently singular) output of the accumulated operations. + pub output: VarTensor, + /// the VarTensor reserved for lookup operations (could be an element of inputs or the same as output) + /// Note that you should be careful to ensure that the lookup_output is not simultaneously assigned to by other non-lookup operations eg. in the case of composite ops. + pub lookup_output: VarTensor, + /// + pub lookup_index: VarTensor, + /// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure [BaseOp]. + pub selectors: BTreeMap<(BaseOp, usize, usize), Selector>, + /// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many lookup ops. + pub lookup_selectors: BTreeMap<(LookupOp, usize, usize), Selector>, + /// + pub tables: BTreeMap>, + /// Activate sanity checks + pub check_mode: CheckMode, + _marker: PhantomData, +} + +impl BaseConfig { + /// Returns a new [BaseConfig] with no inputs, no selectors, and no tables. + pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self { + let dummy_var = VarTensor::dummy(col_size, num_inner_cols); + + Self { + inputs: vec![dummy_var.clone(), dummy_var.clone()], + lookup_input: dummy_var.clone(), + output: dummy_var.clone(), + lookup_output: dummy_var.clone(), + lookup_index: dummy_var, + selectors: BTreeMap::new(), + lookup_selectors: BTreeMap::new(), + tables: BTreeMap::new(), + check_mode: CheckMode::SAFE, + _marker: PhantomData, + } + } + + /// Configures [BaseOp]s for a given [ConstraintSystem]. + /// # Arguments + /// * `meta` - The [ConstraintSystem] to configure the operations in. 
+ /// * `inputs` - The explicit inputs to the operations. + /// * `output` - The variable representing the (currently singular) output of the operations. + /// * `check_mode` - The variable representing the (currently singular) output of the operations. + pub fn configure( + meta: &mut ConstraintSystem, + inputs: &[VarTensor; 2], + output: &VarTensor, + check_mode: CheckMode, + ) -> Self { + // setup a selector per base op + let mut nonaccum_selectors = BTreeMap::new(); + let mut accum_selectors = BTreeMap::new(); + + if !(inputs[0].num_cols() == inputs[1].num_cols()) { + log::warn!("input shapes do not match"); + } + if !(inputs[0].num_cols() == output.num_cols()) { + log::warn!("input and output shapes do not match"); + } + + for i in 0..output.num_blocks() { + for j in 0..output.num_inner_cols() { + nonaccum_selectors.insert((BaseOp::Add, i, j), meta.selector()); + nonaccum_selectors.insert((BaseOp::Sub, i, j), meta.selector()); + nonaccum_selectors.insert((BaseOp::Neg, i, j), meta.selector()); + nonaccum_selectors.insert((BaseOp::Mult, i, j), meta.selector()); + nonaccum_selectors.insert((BaseOp::IsZero, i, j), meta.selector()); + nonaccum_selectors.insert((BaseOp::Identity, i, j), meta.selector()); + nonaccum_selectors.insert((BaseOp::IsBoolean, i, j), meta.selector()); + } + } + + for i in 0..output.num_blocks() { + accum_selectors.insert((BaseOp::DotInit, i, 0), meta.selector()); + accum_selectors.insert((BaseOp::Dot, i, 0), meta.selector()); + accum_selectors.insert((BaseOp::CumProd, i, 0), meta.selector()); + accum_selectors.insert((BaseOp::CumProdInit, i, 0), meta.selector()); + accum_selectors.insert((BaseOp::Sum, i, 0), meta.selector()); + accum_selectors.insert((BaseOp::SumInit, i, 0), meta.selector()); + } + + for ((base_op, block_idx, inner_col_idx), selector) in nonaccum_selectors.iter() { + meta.create_gate(base_op.as_str(), |meta| { + let selector = meta.query_selector(*selector); + + let zero = Expression::::Constant(F::ZERO); + let mut qis = 
vec![zero; 2]; + for (i, q_i) in qis + .iter_mut() + .enumerate() + .take(2) + .skip(2 - base_op.num_inputs()) + { + *q_i = inputs[i] + .query_rng(meta, *block_idx, *inner_col_idx, 0, 1) + .expect("non accum: input query failed")[0] + .clone() + } + + // Get output expressions for each input channel + let (rotation_offset, rng) = base_op.query_offset_rng(); + + let constraints = match base_op { + BaseOp::IsBoolean => { + vec![(qis[1].clone()) * (qis[1].clone() - Expression::Constant(F::from(1)))] + } + BaseOp::IsZero => vec![qis[1].clone()], + _ => { + let expected_output: Tensor> = output + .query_rng(meta, *block_idx, *inner_col_idx, rotation_offset, rng) + .expect("non accum: output query failed"); + + let res = base_op.nonaccum_f((qis[0].clone(), qis[1].clone())); + vec![expected_output[base_op.constraint_idx()].clone() - res] + } + }; + + Constraints::with_selector(selector, constraints) + }); + } + + for ((base_op, block_idx, _), selector) in accum_selectors.iter() { + meta.create_gate(base_op.as_str(), |meta| { + let selector = meta.query_selector(*selector); + let mut qis = vec![vec![]; 2]; + for (i, q_i) in qis + .iter_mut() + .enumerate() + .take(2) + .skip(2 - base_op.num_inputs()) + { + *q_i = inputs[i] + .query_whole_block(meta, *block_idx, 0, 1) + .expect("accum: input query failed") + .into_iter() + .collect() + } + + // Get output expressions for each input channel + let (rotation_offset, rng) = base_op.query_offset_rng(); + + let expected_output: Tensor> = output + .query_rng(meta, *block_idx, 0, rotation_offset, rng) + .expect("accum: output query failed"); + + let res = + base_op.accum_f(expected_output[0].clone(), qis[0].clone(), qis[1].clone()); + let constraints = vec![expected_output[base_op.constraint_idx()].clone() - res]; + + Constraints::with_selector(selector, constraints) + }); + } + + // selectors is the merger of nonaccum and accum selectors + let selectors = nonaccum_selectors + .into_iter() + .chain(accum_selectors) + .collect(); + 
+ Self { + selectors, + lookup_selectors: BTreeMap::new(), + inputs: inputs.to_vec(), + lookup_input: VarTensor::Empty, + lookup_output: VarTensor::Empty, + lookup_index: VarTensor::Empty, + tables: BTreeMap::new(), + output: output.clone(), + check_mode, + _marker: PhantomData, + } + } + + /// Configures and creates lookup selectors + #[allow(clippy::too_many_arguments)] + pub fn configure_lookup( + &mut self, + cs: &mut ConstraintSystem, + input: &VarTensor, + output: &VarTensor, + index: &VarTensor, + lookup_range: (i128, i128), + logrows: usize, + nl: &LookupOp, + ) -> Result<(), Box> + where + F: Field, + { + let mut selectors = BTreeMap::new(); + + if !index.is_advice() { + return Err("wrong input type for lookup index".into()); + } + if !input.is_advice() { + return Err("wrong input type for lookup input".into()); + } + if !output.is_advice() { + return Err("wrong input type for lookup output".into()); + } + + // we borrow mutably twice so we need to do this dance + + let table = if !self.tables.contains_key(nl) { + // as all tables have the same input we see if there's another table who's input we can reuse + let table = if let Some(table) = self.tables.values().next() { + Table::::configure( + cs, + lookup_range, + logrows, + nl, + Some(table.table_inputs.clone()), + ) + } else { + Table::::configure(cs, lookup_range, logrows, nl, None) + }; + self.tables.insert(nl.clone(), table.clone()); + table + } else { + return Ok(()); + }; + + for x in 0..input.num_blocks() { + for y in 0..input.num_inner_cols() { + let len = table.selector_constructor.degree; + + let multi_col_selector = cs.complex_selector(); + + for ((col_idx, input_col), output_col) in table + .table_inputs + .iter() + .enumerate() + .zip(table.table_outputs.iter()) + { + cs.lookup("", |cs| { + let mut res = vec![]; + let sel = cs.query_selector(multi_col_selector); + + let synthetic_sel = match len { + 1 => Expression::Constant(F::from(1)), + _ => match index { + VarTensor::Advice { inner: 
advices, .. } => { + cs.query_advice(advices[x][y], Rotation(0)) + } + _ => unreachable!(), + }, + }; + + let input_query = match &input { + VarTensor::Advice { inner: advices, .. } => { + cs.query_advice(advices[x][y], Rotation(0)) + } + _ => unreachable!(), + }; + + let output_query = match &output { + VarTensor::Advice { inner: advices, .. } => { + cs.query_advice(advices[x][y], Rotation(0)) + } + _ => unreachable!(), + }; + + // we index from 1 to avoid the zero element creating soundness issues + // this is 0 if the index is the same as the column index (starting from 1) + + let col_expr = sel.clone() + * table + .selector_constructor + .get_expr_at_idx(col_idx, synthetic_sel); + + let multiplier = + table.selector_constructor.get_selector_val_at_idx(col_idx); + + let not_expr = Expression::Constant(multiplier) - col_expr.clone(); + + let (default_x, default_y) = table.get_first_element(col_idx); + + log::trace!("---------------- col {:?} ------------------", col_idx,); + log::trace!("expr: {:?}", col_expr,); + log::trace!("multiplier: {:?}", multiplier); + log::trace!("not_expr: {:?}", not_expr); + log::trace!("default x: {:?}", default_x); + log::trace!("default y: {:?}", default_y); + + res.extend([ + ( + col_expr.clone() * input_query.clone() + + not_expr.clone() * Expression::Constant(default_x), + *input_col, + ), + ( + col_expr.clone() * output_query.clone() + + not_expr.clone() * Expression::Constant(default_y), + *output_col, + ), + ]); + + res + }); + } + selectors.insert((nl.clone(), x, y), multi_col_selector); + } + } + self.lookup_selectors.extend(selectors); + // if we haven't previously initialized the input/output, do so now + if let VarTensor::Empty = self.lookup_input { + debug!("assigning lookup input"); + self.lookup_input = input.clone(); + } + if let VarTensor::Empty = self.lookup_output { + debug!("assigning lookup output"); + self.lookup_output = output.clone(); + } + if let VarTensor::Empty = self.lookup_index { + debug!("assigning 
lookup index"); + self.lookup_index = index.clone(); + } + Ok(()) + } + + /// layout_tables must be called before layout. + pub fn layout_tables(&mut self, layouter: &mut impl Layouter) -> Result<(), Box> { + for (i, table) in self.tables.values_mut().enumerate() { + if !table.is_assigned { + debug!( + "laying out table for {}", + crate::circuit::ops::Op::::as_string(&table.nonlinearity) + ); + if i == 0 { + table.layout(layouter, false)?; + } else { + table.layout(layouter, true)?; + } + } + } + Ok(()) + } + + /// Assigns variables to the regions created when calling `configure`. + /// # Arguments + /// * `values` - The explicit values to the operations. + /// * `layouter` - A Halo2 Layouter. + /// * `op` - The operation being represented. + pub fn layout( + &mut self, + region: &mut RegionCtx, + values: &[ValTensor], + op: Box>, + ) -> Result>, Box> { + let res = op.layout(self, region, values)?; + + if matches!(&self.check_mode, CheckMode::SAFE) && !region.is_dummy() { + if let Some(claimed_output) = &res { + // during key generation this will be unknown vals so we use this as a flag to check + let mut is_assigned = !claimed_output.any_unknowns()?; + for val in values.iter() { + is_assigned = is_assigned && !val.any_unknowns()?; + } + if is_assigned { + op.safe_mode_check(claimed_output, values)?; + } + } + }; + Ok(res) + } +} diff --git a/mnist_ezkl/src/circuit/ops/hybrid.rs b/mnist_ezkl/src/circuit/ops/hybrid.rs new file mode 100644 index 0000000..88295bb --- /dev/null +++ b/mnist_ezkl/src/circuit/ops/hybrid.rs @@ -0,0 +1,460 @@ +use super::*; +use crate::{ + circuit::{self, layouts, utils, Tolerance}, + fieldutils::{felt_to_i128, i128_to_felt}, + tensor::{self, Tensor, TensorError, TensorType, ValTensor}, +}; +use halo2curves::ff::PrimeField; +use itertools::Itertools; +use serde::{Deserialize, Serialize}; +// import run args from model + +#[allow(missing_docs)] +/// An enum representing the operations that consist of both lookups and arithmetic operations. 
+#[derive(Clone, Debug, Serialize, Deserialize)] +pub enum HybridOp { + ReduceMax { + axes: Vec, + }, + ReduceArgMax { + dim: usize, + }, + MaxPool2d { + padding: [(usize, usize); 2], + stride: (usize, usize), + pool_dims: (usize, usize), + }, + ReduceMin { + axes: Vec, + }, + ReduceArgMin { + dim: usize, + }, + Softmax { + scale: utils::F32, + axes: Vec, + }, + RangeCheck(Tolerance), + Greater, + GreaterEqual, + Less, + LessEqual, + Equals, + Gather { + dim: usize, + constant_idx: Option>, + }, + TopK { + dim: usize, + k: usize, + }, + OneHot { + dim: usize, + num_classes: usize, + }, + GatherElements { + dim: usize, + constant_idx: Option>, + }, + ScatterElements { + dim: usize, + constant_idx: Option>, + }, +} + +impl Op for HybridOp { + /// + fn requires_homogenous_input_scales(&self) -> Vec { + match self { + HybridOp::Greater | HybridOp::Less | HybridOp::Equals => vec![0, 1], + HybridOp::ScatterElements { .. } => vec![0, 2], + _ => vec![], + } + } + + /// Returns a reference to the Any trait. + fn as_any(&self) -> &dyn Any { + self + } + /// Matches a [Op] to an operation in the `tensor::ops` module. + fn f(&self, inputs: &[Tensor]) -> Result, TensorError> { + let x = inputs[0].clone().map(|x| felt_to_i128(x)); + + let (res, intermediate_lookups) = match &self { + HybridOp::ReduceMax { axes, .. } => { + let res = tensor::ops::max_axes(&x, axes)?; + let max_minus_one = + Tensor::from(vec![x.clone().into_iter().max().unwrap() - 1].into_iter()); + let unit = Tensor::from(vec![1].into_iter()); + // relu(x - max(x - 1) + let inter_1 = (x.clone() - max_minus_one)?; + // relu(1 - sum(relu(inter_1))) + let inter_2 = (unit + - tensor::ops::sum(&tensor::ops::nonlinearities::leakyrelu(&inter_1, 0.0))?)?; + + (res.clone(), vec![inter_1, inter_2]) + } + HybridOp::ReduceMin { axes, .. 
} => { + let res = tensor::ops::min_axes(&x, axes)?; + let min_plus_one = + Tensor::from(vec![x.clone().into_iter().min().unwrap() + 1].into_iter()); + let unit = Tensor::from(vec![1].into_iter()); + // relu(min(x + 1) - x) + let inter_1 = (min_plus_one - x.clone())?; + // relu(1 - sum(relu(inter_1))) + let inter_2 = (unit + - tensor::ops::sum(&tensor::ops::nonlinearities::leakyrelu(&inter_1, 0.0))?)?; + (res.clone(), vec![inter_1, inter_2]) + } + HybridOp::ReduceArgMax { dim } => { + let res = tensor::ops::argmax_axes(&x, *dim)?; + let indices = Tensor::from(0..x.dims()[*dim] as i128); + let mut inter_equals: Vec> = vec![indices.clone(), -indices]; + let inter = + Op::f(&HybridOp::ReduceMax { axes: vec![*dim] }, inputs)?.intermediate_lookups; + inter_equals.extend(inter); + + (res.clone(), inter_equals) + } + HybridOp::ReduceArgMin { dim } => { + let res = tensor::ops::argmin_axes(&x, *dim)?; + let indices = Tensor::from(0..x.dims()[*dim] as i128); + let mut inter_equals: Vec> = vec![indices.clone(), -indices]; + let inter = + Op::f(&HybridOp::ReduceMin { axes: vec![*dim] }, inputs)?.intermediate_lookups; + inter_equals.extend(inter); + + (res.clone(), inter_equals) + } + HybridOp::Gather { dim, constant_idx } => { + if let Some(idx) = constant_idx { + log::debug!("idx: {}", idx.show()); + let res = tensor::ops::gather(&x, idx, *dim)?; + (res.clone(), vec![]) + } else { + let y = inputs[1].clone().map(|x| felt_to_i128(x)); + let indices = Tensor::from(0..x.dims()[*dim] as i128); + let inter_equals: Vec> = vec![indices.clone(), -indices]; + let res = tensor::ops::gather(&x, &y.map(|x| x as usize), *dim)?; + (res.clone(), inter_equals) + } + } + HybridOp::OneHot { dim, num_classes } => { + let indices = Tensor::from(0..x.dims()[*dim] as i128); + let inter_equals: Vec> = vec![indices.clone(), -indices]; + let res = tensor::ops::one_hot(&x, *num_classes, *dim)?; + (res.clone(), inter_equals) + } + HybridOp::TopK { dim, k } => { + let res = tensor::ops::topk_axes(&x, 
*k, *dim)?; + + let mut inter_equals = x + .clone() + .into_iter() + .flat_map(|elem| { + tensor::ops::equals(&res, &vec![elem].into_iter().into()) + .unwrap() + .1 + }) + .collect::>(); + + // sort in descending order and take pairwise differences + inter_equals.push( + x.into_iter() + .sorted() + .tuple_windows() + .map(|(a, b)| b - a) + .into(), + ); + + (res.clone(), inter_equals) + } + HybridOp::GatherElements { dim, constant_idx } => { + if let Some(idx) = constant_idx { + log::debug!("idx: {}", idx.show()); + let res = tensor::ops::gather_elements(&x, idx, *dim)?; + (res.clone(), vec![]) + } else { + let y = inputs[1].clone().map(|x| felt_to_i128(x)); + let indices = Tensor::from(0..x.dims()[*dim] as i128); + let inter_equals: Vec> = vec![indices.clone(), -indices]; + let res = tensor::ops::gather_elements(&x, &y.map(|x| x as usize), *dim)?; + (res.clone(), inter_equals) + } + } + HybridOp::ScatterElements { dim, constant_idx } => { + let src = inputs[2].clone().map(|x| felt_to_i128(x)); + if let Some(idx) = constant_idx { + log::debug!("idx: {}", idx.show()); + let res = tensor::ops::scatter(&x, idx, &src, *dim)?; + (res.clone(), vec![]) + } else { + let idx = inputs[1].clone().map(|x| felt_to_i128(x)); + let indices = Tensor::from(0..x.dims()[*dim] as i128); + let inter_equals: Vec> = vec![indices.clone(), -indices]; + let res = tensor::ops::scatter(&x, &idx.map(|x| x as usize), &src, *dim)?; + (res.clone(), inter_equals) + } + } + HybridOp::MaxPool2d { + padding, + stride, + pool_dims, + .. 
+ } => { + let max_minus_one = + Tensor::from(vec![x.clone().into_iter().max().unwrap() - 1].into_iter()); + let unit = Tensor::from(vec![1].into_iter()); + // relu(x - max(x - 1) + let inter_1 = (x.clone() - max_minus_one)?; + // relu(1 - sum(relu(inter_1))) + let inter_2 = (unit + - tensor::ops::sum(&tensor::ops::nonlinearities::leakyrelu(&inter_1, 0.0))?)?; + ( + tensor::ops::max_pool2d(&x, padding, stride, pool_dims)?, + vec![inter_1, inter_2], + ) + } + HybridOp::Softmax { scale, axes } => { + tensor::ops::nonlinearities::softmax_axes(&x, scale.into(), axes) + } + HybridOp::RangeCheck(tol) => { + let y = inputs[1].clone().map(|x| felt_to_i128(x)); + ( + tensor::ops::nonlinearities::range_check_percent(&[x, y], 128, 128, tol.val), + vec![], + ) + } + HybridOp::Greater => { + let y = inputs[1].clone().map(|x| felt_to_i128(x)); + tensor::ops::greater(&x, &y)? + } + HybridOp::GreaterEqual => { + let y = inputs[1].clone().map(|x| felt_to_i128(x)); + tensor::ops::greater_equal(&x, &y)? + } + HybridOp::Less => { + let y = inputs[1].clone().map(|x| felt_to_i128(x)); + tensor::ops::less(&x, &y)? + } + HybridOp::LessEqual => { + let y = inputs[1].clone().map(|x| felt_to_i128(x)); + tensor::ops::less_equal(&x, &y)? + } + HybridOp::Equals => { + let y = inputs[1].clone().map(|x| felt_to_i128(x)); + tensor::ops::equals(&x, &y)? 
+ } + }; + + // convert back to felt + let output = res.map(|x| i128_to_felt(x)); + + Ok(ForwardResult { + output, + intermediate_lookups, + }) + } + + fn as_string(&self) -> String { + match self { + HybridOp::ReduceMax { axes } => format!("REDUCEMAX (axes={:?})", axes), + HybridOp::ReduceArgMax { dim } => format!("REDUCEARGMAX (dim={})", dim), + HybridOp::MaxPool2d { + padding, + stride, + pool_dims, + } => format!( + "MAXPOOL2D (padding={:?}, stride={:?}, pool_dims={:?})", + padding, stride, pool_dims + ), + HybridOp::ReduceMin { axes } => format!("REDUCEMIN (axes={:?})", axes), + HybridOp::ReduceArgMin { dim } => format!("REDUCEARGMIN (dim={})", dim), + HybridOp::Softmax { scale, axes } => { + format!("SOFTMAX (scale={}, axes={:?})", scale, axes) + } + HybridOp::RangeCheck(p) => format!("RANGECHECK (tol={:?})", p), + HybridOp::Greater => "GREATER".into(), + HybridOp::GreaterEqual => "GREATEREQUAL".into(), + HybridOp::Less => "LESS".into(), + HybridOp::LessEqual => "LESSEQUAL".into(), + HybridOp::Equals => "EQUALS".into(), + HybridOp::Gather { dim, .. } => format!("GATHER (dim={})", dim), + HybridOp::TopK { k, dim } => format!("TOPK (k={}, dim={})", k, dim), + HybridOp::GatherElements { dim, .. } => format!("GATHERELEMENTS (dim={})", dim), + HybridOp::ScatterElements { dim, .. } => format!("SCATTERELEMENTS (dim={})", dim), + HybridOp::OneHot { dim, num_classes } => { + format!("ONEHOT (dim={}, num_classes={})", dim, num_classes) + } + } + } + + fn layout( + &self, + config: &mut crate::circuit::BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor], + ) -> Result>, Box> { + Ok(Some(match self { + HybridOp::Gather { dim, constant_idx } => { + if let Some(idx) = constant_idx { + tensor::ops::gather(values[0].get_inner_tensor()?, idx, *dim)?.into() + } else { + layouts::gather(config, region, values[..].try_into()?, *dim)? 
+ } + } + HybridOp::GatherElements { dim, constant_idx } => { + if let Some(idx) = constant_idx { + tensor::ops::gather_elements(values[0].get_inner_tensor()?, idx, *dim)?.into() + } else { + layouts::gather_elements(config, region, values[..].try_into()?, *dim)? + } + } + HybridOp::ScatterElements { dim, constant_idx } => { + if let Some(idx) = constant_idx { + tensor::ops::scatter( + values[0].get_inner_tensor()?, + idx, + values[1].get_inner_tensor()?, + *dim, + )? + .into() + } else { + layouts::scatter_elements(config, region, values[..].try_into()?, *dim)? + } + } + HybridOp::MaxPool2d { + padding, + stride, + pool_dims, + } => layouts::max_pool2d( + config, + region, + values[..].try_into()?, + *padding, + *stride, + *pool_dims, + )?, + HybridOp::ReduceMax { axes } => { + layouts::max_axes(config, region, values[..].try_into()?, axes)? + } + HybridOp::ReduceArgMax { dim } => { + layouts::argmax_axes(config, region, values[..].try_into()?, *dim)? + } + HybridOp::ReduceMin { axes } => { + layouts::min_axes(config, region, values[..].try_into()?, axes)? + } + HybridOp::ReduceArgMin { dim } => { + layouts::argmin_axes(config, region, values[..].try_into()?, *dim)? + } + HybridOp::Softmax { scale, axes } => { + layouts::softmax_axes(config, region, values[..].try_into()?, *scale, axes)? + } + HybridOp::RangeCheck(tol) => layouts::range_check_percent( + config, + region, + values[..].try_into()?, + tol.scale, + tol.val, + )?, + HybridOp::Greater => layouts::greater(config, region, values[..].try_into()?)?, + HybridOp::GreaterEqual => { + layouts::greater_equal(config, region, values[..].try_into()?)? + } + HybridOp::Less => layouts::less(config, region, values[..].try_into()?)?, + HybridOp::LessEqual => layouts::less_equal(config, region, values[..].try_into()?)?, + HybridOp::Equals => layouts::equals(config, region, values[..].try_into()?)?, + HybridOp::TopK { dim, k } => { + layouts::topk_axes(config, region, values[..].try_into()?, *k, *dim)? 
+ } + HybridOp::OneHot { dim, num_classes } => { + layouts::one_hot_axis(config, region, values[..].try_into()?, *num_classes, *dim)? + } + })) + } + + fn out_scale(&self, in_scales: Vec) -> Result> { + let scale = match self { + HybridOp::Greater { .. } + | HybridOp::GreaterEqual { .. } + | HybridOp::Less { .. } + | HybridOp::LessEqual { .. } + | HybridOp::ReduceArgMax { .. } + | HybridOp::OneHot { .. } + | HybridOp::ReduceArgMin { .. } => 0, + HybridOp::Softmax { .. } => 2 * in_scales[0], + _ => in_scales[0], + }; + Ok(scale) + } + + fn required_lookups(&self) -> Vec { + match self { + HybridOp::ReduceMax { .. } + | HybridOp::ReduceMin { .. } + | HybridOp::MaxPool2d { .. } => Op::::required_lookups(&LookupOp::ReLU), + HybridOp::Softmax { scale, .. } => { + vec![ + LookupOp::Exp { scale: *scale }, + LookupOp::Recip { + scale: scale.0.powf(2.0).into(), + }, + ] + } + HybridOp::RangeCheck(tol) => { + let mut lookups = vec![]; + if tol.val > 0.0 { + let scale_squared = tol.scale.0.powf(2.0); + lookups.extend([ + LookupOp::Recip { + scale: scale_squared.into(), + }, + LookupOp::GreaterThan { + a: circuit::utils::F32((tol.val * scale_squared) / 100.0), + }, + ]); + } + lookups + } + HybridOp::Greater { .. } | HybridOp::Less { .. } => { + vec![LookupOp::GreaterThan { + a: circuit::utils::F32(0.), + }] + } + HybridOp::GreaterEqual { .. } | HybridOp::LessEqual { .. } => { + vec![LookupOp::GreaterThanEqual { + a: circuit::utils::F32(0.), + }] + } + HybridOp::TopK { .. } => { + vec![ + LookupOp::GreaterThan { + a: circuit::utils::F32(0.), + }, + LookupOp::KroneckerDelta, + ] + } + HybridOp::Gather { + constant_idx: None, .. + } + | HybridOp::OneHot { .. } + | HybridOp::GatherElements { + constant_idx: None, .. + } + | HybridOp::ScatterElements { + constant_idx: None, .. + } + | HybridOp::Equals { .. } => { + vec![LookupOp::KroneckerDelta] + } + HybridOp::ReduceArgMax { .. } | HybridOp::ReduceArgMin { .. 
} => { + vec![LookupOp::ReLU, LookupOp::KroneckerDelta] + } + _ => vec![], + } + } + + fn clone_dyn(&self) -> Box> { + Box::new(self.clone()) // Forward to the derive(Clone) impl + } +} diff --git a/mnist_ezkl/src/circuit/ops/layouts.rs b/mnist_ezkl/src/circuit/ops/layouts.rs new file mode 100644 index 0000000..ca513a4 --- /dev/null +++ b/mnist_ezkl/src/circuit/ops/layouts.rs @@ -0,0 +1,2829 @@ +use std::{ + collections::{HashMap, HashSet}, + error::Error, + ops::Range, +}; + +use halo2_proofs::circuit::Value; +use halo2curves::ff::PrimeField; +use itertools::Itertools; +use log::{error, trace}; +use rayon::{ + prelude::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}, + slice::ParallelSliceMut, +}; + +use super::{ + chip::{BaseConfig, CircuitError}, + region::RegionCtx, +}; +use crate::{ + circuit::{ops::base::BaseOp, utils}, + fieldutils::i128_to_felt, + tensor::{ + get_broadcasted_shape, + ops::{accumulated, add, mult, sub}, + Tensor, TensorError, ValType, + }, +}; + +use super::*; +use crate::circuit::ops::lookup::LookupOp; + +/// +pub fn overflowed_len(starting_idx: usize, mut total_len: usize, column_len: usize) -> usize { + let mut idx = starting_idx; + // let x = idx / column_len; + let y = idx % column_len; + if y + total_len < column_len { + return total_len; + } + // fill up first column + idx += column_len - y; + total_len += 1; + loop { + if idx >= starting_idx + total_len { + break; + } + idx += column_len; + total_len += 1; + } + total_len +} + +/// Dot product accumulated layout +pub fn dot( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 2], +) -> Result, Box> { + region.flush()?; + // time this entire function run + let global_start = instant::Instant::now(); + + let mut values = values.clone(); + + // this section has been optimized to death, don't mess with it + let mut removal_indices = values[0].get_const_zero_indices()?; + let second_zero_indices = values[1].get_const_zero_indices()?; + 
removal_indices.extend(second_zero_indices); + removal_indices.par_sort_unstable(); + removal_indices.dedup(); + + // is already sorted + values[0].remove_indices(&mut removal_indices, true)?; + values[1].remove_indices(&mut removal_indices, true)?; + + let elapsed = global_start.elapsed(); + trace!("filtering const zero indices took: {:?}", elapsed); + + if values[0].len() != values[1].len() { + return Err(Box::new(TensorError::DimMismatch("dot".to_string()))); + } + + // if empty return a const + if values[0].is_empty() && values[1].is_empty() { + return Ok(Tensor::from([ValType::Constant(F::ZERO)].into_iter()).into()); + } + + let start = instant::Instant::now(); + let mut inputs = vec![]; + let block_width = config.output.num_inner_cols(); + + let mut assigned_len = 0; + for (i, input) in values.iter_mut().enumerate() { + input.pad_to_zero_rem(block_width)?; + let inp = { + let (res, len) = region.assign_with_duplication( + &config.inputs[i], + input, + &config.check_mode, + false, + )?; + assigned_len = len; + res.get_inner()? 
+ }; + inputs.push(inp); + } + + let elapsed = start.elapsed(); + trace!("assigning inputs took: {:?}", elapsed); + + // Now we can assign the dot product + // time this step + let start = instant::Instant::now(); + let accumulated_dot = accumulated::dot(&[inputs[0].clone(), inputs[1].clone()], block_width)?; + let elapsed = start.elapsed(); + trace!("calculating accumulated dot took: {:?}", elapsed); + + let start = instant::Instant::now(); + let (output, output_assigned_len) = region.assign_with_duplication( + &config.output, + &accumulated_dot.into(), + &config.check_mode, + true, + )?; + let elapsed = start.elapsed(); + trace!("assigning output took: {:?}", elapsed); + + // enable the selectors + if !region.is_dummy() { + (0..output_assigned_len) + .map(|i| { + let (x, _, z) = config + .output + .cartesian_coord(region.linear_coord() + i * block_width); + // hop over duplicates at start of column + if z == 0 && i > 0 { + return Ok(()); + } + let selector = if i == 0 { + config.selectors.get(&(BaseOp::DotInit, x, 0)) + } else { + config.selectors.get(&(BaseOp::Dot, x, 0)) + }; + region.enable(selector, z)?; + + Ok(()) + }) + .collect::, Box>>()?; + } + + let last_elem = output.get_slice(&[output.len() - 1..output.len()])?; + + region.increment(assigned_len); + + // last element is the result + + let elapsed = global_start.elapsed(); + trace!("dot layout took: {:?}, row {}", elapsed, region.row()); + trace!("----------------------------"); + Ok(last_elem) +} + +/// Einsum +pub fn einsum( + config: &BaseConfig, + region: &mut RegionCtx, + inputs: &[ValTensor], + equation: &str, +) -> Result, Box> { + let mut equation = equation.split("->"); + let inputs_eq = equation.next().ok_or(CircuitError::InvalidEinsum)?; + let output_eq = equation.next().ok_or(CircuitError::InvalidEinsum)?; + let inputs_eq = inputs_eq.split(',').collect::>(); + + // Check that the number of inputs matches the number of inputs in the equation + if inputs.len() != inputs_eq.len() { + return 
Err(Box::new(TensorError::DimMismatch("einsum".to_string()))); + } + + let mut indices_to_size = HashMap::new(); + for (i, input) in inputs.iter().enumerate() { + for j in 0..inputs_eq[i].len() { + let c = inputs_eq[i] + .chars() + .nth(j) + .ok_or(CircuitError::InvalidEinsum)?; + if let std::collections::hash_map::Entry::Vacant(e) = indices_to_size.entry(c) { + e.insert(input.dims()[j]); + } else if indices_to_size[&c] != input.dims()[j] { + return Err(Box::new(TensorError::DimMismatch("einsum".to_string()))); + } + } + } + + // maps unrepresented indices in the output to a trivial 1 + for c in output_eq.chars() { + indices_to_size.entry(c).or_insert(1); + } + + // Compute the output tensor shape + let mut output_shape: Vec = output_eq + .chars() + .map(|c| { + indices_to_size + .get(&c) + .ok_or(CircuitError::InvalidEinsum) + .copied() + }) + .collect::, _>>()?; + + if output_shape.is_empty() { + output_shape.push(1); + } + + // Create a new output tensor with the computed shape + let mut output: Tensor> = Tensor::new(None, &output_shape)?; + + let mut seen = HashSet::new(); + let mut common_indices_to_inputs = vec![]; + for input in inputs_eq.iter().take(inputs.len()) { + for c in input.chars() { + if !seen.contains(&c) { + seen.insert(c); + } else { + common_indices_to_inputs.push(c); + } + } + } + + let non_common_indices = indices_to_size + .keys() + .filter(|&x| !common_indices_to_inputs.contains(x)) + .collect::>(); + + let non_common_coord_size = non_common_indices + .iter() + .map(|d| { + // If the current index is in the output equation, then the slice should be the current coordinate + if output_eq.contains(**d) { + Ok(1) + // Otherwise, the slice should be the entire dimension of the input tensor + } else { + indices_to_size + .get(d) + .ok_or(CircuitError::InvalidEinsum) + .copied() + } + }) + .collect::, _>>()? 
+ .iter() + .product::(); + + let cartesian_coord = output_shape + .iter() + .map(|d| 0..*d) + .multi_cartesian_product() + .collect::>(); + + // Get the indices common accross input tensors + let mut common_coord = common_indices_to_inputs + .iter() + .map(|d| { + // If the current index is in the output equation, then the slice should be the current coordinate + if output_eq.contains(*d) { + Ok(0..1) + // Otherwise, the slice should be the entire dimension of the input tensor + } else { + Ok(0..*indices_to_size.get(d).ok_or(CircuitError::InvalidEinsum)?) + } + }) + .collect::>, Box>>()? + .into_iter() + .multi_cartesian_product() + .collect::>(); + + // If there are no common indices, then we need to add an empty slice to force one iteration of the loop + if common_coord.is_empty() { + common_coord.push(vec![]); + } + + let inner_loop_function = |i: usize, region: &mut RegionCtx<'_, F>| { + let coord = cartesian_coord[i].clone(); + // Compute the slice of each input tensor given the current coordinate of the output tensor + let inputs = (0..inputs.len()) + .map(|idx| { + let mut slice = vec![]; + for (i, c) in inputs_eq[idx].chars().enumerate() { + // If the current index is in the output equation, then the slice should be the current coordinate + if let Some(idx) = output_eq.find(c) { + slice.push(coord[idx]..coord[idx] + 1); + // Otherwise, the slice should be the entire dimension of the input tensor + } else { + slice.push(0..inputs[idx].dims()[i]); + } + } + // Get the slice of the input tensor + inputs[idx].get_slice(&slice) + }) + .collect::, _>>()?; + + // in this case its just a dot product :) + if non_common_coord_size == 1 && inputs.len() == 2 { + Ok(dot( + config, + region, + inputs[..].try_into().map_err(|e| { + error!("{}", e); + halo2_proofs::plonk::Error::Synthesis + })?, + )? 
+ .get_inner_tensor()?[0] + .clone()) + } else { + let mut prod = None; + + // Compute the cartesian product of all common indices + for common_dim in &common_coord { + let inputs = (0..inputs.len()) + .map(|idx| { + let mut slice = vec![]; + // Iterate over all indices in the input equation + for (i, c) in inputs_eq[idx].chars().enumerate() { + // If the current index is common to multiple inputs, then the slice should be the current coordinate + if let Some(j) = common_indices_to_inputs.iter().position(|&r| r == c) { + slice.push(common_dim[j]..common_dim[j] + 1); + } else { + slice.push(0..inputs[idx].dims()[i]); + } + } + // Get the slice of the input tensor + inputs[idx].get_slice(&slice).map_err(|e| { + error!("{}", e); + halo2_proofs::plonk::Error::Synthesis + }) + }) + .collect::, _>>()?; + + let input_pairs = inputs + .iter() + .map(|d| d.get_inner_tensor().into_iter()) + .multi_cartesian_product() + .collect::>(); + + // Compute the product of all input tensors + for pair in input_pairs { + let product_across_pair: Result, halo2_proofs::plonk::Error> = + pair[1..] + .iter() + .fold(Ok(ValTensor::from(pair[0].clone())), |acc, x| { + pairwise(config, region, &[acc?, (*x).clone().into()], BaseOp::Mult) + .map_err(|e| { + error!("{}", e); + halo2_proofs::plonk::Error::Synthesis + }) + }); + let product_across_pair = product_across_pair?; + + if let Some(product) = prod { + prod = Some( + pairwise(config, region, &[product, product_across_pair], BaseOp::Add) + .map_err(|e| { + error!("{}", e); + halo2_proofs::plonk::Error::Synthesis + })?, + ); + } else { + prod = Some(product_across_pair); + } + } + } + Ok::<_, region::RegionError>( + prod.ok_or(Into::::into("missing prod"))? 
+ .get_inner_tensor()?[0] + .clone(), + ) + } + }; + + region.flush()?; + region.apply_in_loop(&mut output, inner_loop_function)?; + + let output: ValTensor = output.into(); + + Ok(output) +} + +fn _sort_descending( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], +) -> Result, Box> { + let input = values[0].clone(); + + // assert input is flat + assert_eq!(input.dims().len(), 1); + + let is_assigned = !input.any_unknowns()?; + + let sorted = if is_assigned { + input + .get_int_evals()? + .iter() + .sorted_by(|a, b| b.cmp(a)) + .map(|x| Value::known(i128_to_felt(*x))) + .collect::>() + } else { + Tensor::new( + Some(&vec![Value::::unknown(); input.len()]), + &[input.len()], + )? + }; + + let assigned_sort = region.assign(&config.inputs[0], &sorted.into())?; + let input = region.assign(&config.inputs[1], &input)?; + + let mut unit = Tensor::from(vec![F::from(1)].into_iter()); + unit.set_visibility(&crate::graph::Visibility::Fixed); + let unit = region.assign(&config.output, &unit.try_into()?)?; + + region.increment(assigned_sort.len()); + + for i in 0..assigned_sort.len() - 1 { + // assert that each thing in turn is larger than the next + let window_a = assigned_sort.get_slice(&[i..i + 1])?; + let window_b = assigned_sort.get_slice(&[i + 1..i + 2])?; + + let window_b_minus_1 = pairwise(config, region, &[window_b, unit.clone()], BaseOp::Sub)?; + + let diff = pairwise( + config, + region, + &[window_a.clone(), window_b_minus_1.clone()], + BaseOp::Sub, + )?; + let greater_than = nonlinearity( + config, + region, + &[diff], + &LookupOp::GreaterThan { a: 0.0.into() }, + )?; + + enforce_equality(config, region, &[unit.clone(), greater_than.clone()])?; + + // now assert that the elem is in the original vector + let is_present = equals(config, region, &[window_a, input.clone()])?; + let sum_equals = sum(config, region, &[is_present])?; + let greater_than = nonlinearity( + config, + region, + &[sum_equals], + &LookupOp::GreaterThan { a: 0.0.into() 
}, + )?; + + enforce_equality(config, region, &[unit.clone(), greater_than.clone()])?; + } + + Ok(assigned_sort) +} + +fn _sort_ascending( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], +) -> Result, Box> { + let input = values[0].clone(); + + // assert input is flat + assert_eq!(input.dims().len(), 1); + + let is_assigned = !input.any_unknowns()?; + + let sorted = if is_assigned { + input + .get_int_evals()? + .iter() + .sorted_by(|a, b| a.cmp(b)) + .map(|x| Ok(Value::known(input.get_felt_evals()?.get(&[*x as usize])))) + .collect::>, Box>>()? + } else { + Tensor::new( + Some(&vec![Value::::unknown(); input.len()]), + &[input.len()], + )? + }; + + let assigned_sort = region.assign(&config.inputs[0], &sorted.into())?; + + let mut unit = Tensor::from(vec![F::from(1)].into_iter()); + unit.set_visibility(&crate::graph::Visibility::Fixed); + let unit = region.assign(&config.inputs[1], &unit.try_into()?)?; + + region.increment(assigned_sort.len()); + + for i in 0..assigned_sort.len() - 1 { + // assert that each thing in turn is larger than the next + let window_a = assigned_sort.get_slice(&[i..i + 1])?; + let window_b = assigned_sort.get_slice(&[i + 1..i + 2])?; + + let window_a_minus_1 = pairwise( + config, + region, + &[window_a.clone(), unit.clone()], + BaseOp::Sub, + )?; + + let diff = pairwise( + config, + region, + &[window_b.clone(), window_a_minus_1.clone()], + BaseOp::Sub, + )?; + let greater_than = nonlinearity( + config, + region, + &[diff], + &LookupOp::GreaterThan { a: 0.0.into() }, + )?; + + enforce_equality(config, region, &[unit.clone(), greater_than.clone()])?; + + // now assert that the elem is in the original vector + let is_present = equals(config, region, &[window_a, input.clone()])?; + let sum_equals = sum(config, region, &[is_present])?; + let greater_than = nonlinearity( + config, + region, + &[sum_equals], + &LookupOp::GreaterThan { a: 0.0.into() }, + )?; + + enforce_equality(config, region, &[unit.clone(), 
greater_than.clone()])?; + } + + Ok(assigned_sort) +} + +/// +fn _select_topk( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], + k: usize, +) -> Result, Box> { + let sorted = _sort_descending(config, region, values)?.get_slice(&[0..k])?; + Ok(sorted) +} + +/// Select top k elements +pub fn topk_axes( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], + k: usize, + dim: usize, +) -> Result, Box> { + let topk_at_k = move |config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1]| + -> Result, Box> { + _select_topk(config, region, values, k) + }; + + let output: ValTensor = multi_dim_axes_op(config, region, values, &[dim], topk_at_k)?; + + Ok(output) +} + +fn select( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 2], + dim_indices: ValTensor, +) -> Result, Box> { + let (mut input, index) = (values[0].clone(), values[1].clone()); + input.flatten(); + + // assert we have a single index + if !(index.dims().iter().product::() == 1) { + return Err("index must be a single element".into()); + } + + if !(dim_indices.all_prev_assigned() || region.is_dummy()) { + return Err("dim_indices must be assigned".into()); + } + + let is_assigned = !input.any_unknowns()? && !index.any_unknowns()?; + + let output: ValTensor = if is_assigned { + index + .get_int_evals()? + .iter() + .map(|x| Ok(Value::known(input.get_felt_evals()?.get(&[*x as usize])))) + .collect::>, Box>>()? + } else { + Tensor::new( + Some(&vec![Value::::unknown(); index.len()]), + &[index.len()], + )? 
+ } + .into(); + + let local_mask = equals(config, region, &[dim_indices.clone(), index])?; + + let dot = dot(config, region, &[input.clone(), local_mask.clone()])?; + + let assigned_output = enforce_equality(config, region, &[dot, output.clone()])?; + + Ok(assigned_output) +} + +fn one_hot( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], + num_classes: usize, +) -> Result, Box> { + // assert values is flat + assert_eq!(values[0].dims().len(), 1); + // assert its a single elelemnt + assert_eq!(values[0].len(), 1); + let input = values[0].clone(); + let is_assigned = !input.any_unknowns()?; + + let output: ValTensor = if is_assigned { + let int_evals = input.get_int_evals()?; + let res = tensor::ops::one_hot(&int_evals, num_classes, 1)?; + res.iter() + .map(|x| Value::known(i128_to_felt(*x))) + .collect::>() + } else { + Tensor::new( + Some(&vec![Value::::unknown(); num_classes]), + &[num_classes], + )? + } + .into(); + + let assigned_input = region.assign(&config.inputs[0], &input)?; + + // now assert all elems are 0 or 1 + let assigned_output = region.assign(&config.inputs[1], &output)?; + if !region.is_dummy() { + for i in 0..assigned_output.len() { + let (x, y, z) = config.output.cartesian_coord(region.linear_coord() + i); + let selector = config.selectors.get(&(BaseOp::IsBoolean, x, y)); + region.enable(selector, z)?; + } + } + region.increment(std::cmp::max(assigned_output.len(), assigned_input.len())); + + let sum = sum(config, region, &[assigned_output.clone()])?; + // assert sum is 1 + let mut unit = Tensor::from(vec![F::from(1)].into_iter()); + unit.set_visibility(&crate::graph::Visibility::Fixed); + let unit: ValTensor = unit.try_into()?; + + enforce_equality(config, region, &[unit.clone(), sum])?; + + let gathered = gather( + config, + region, + &[assigned_output.clone(), assigned_input.clone()], + 0, + )?; + + enforce_equality(config, region, &[unit, gathered])?; + + Ok(assigned_output) +} + +/// One hot accumulated layout 
+pub fn one_hot_axis( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], + num_classes: usize, + dim: usize, +) -> Result, Box> { + let input = values[0].clone(); + let input_inner = input.get_inner_tensor()?; + + let mut output_dims = values[0].dims().to_vec(); + output_dims.insert(dim, num_classes); + + let mut op_tensors: Tensor> = Tensor::new(None, input_inner.dims())?; + + let inner_loop_function = + |i: usize, region: &mut RegionCtx<'_, F>| -> Result, _> { + let inp = input_inner[i].clone(); + let tensor = Tensor::new(Some(&[inp.clone()]), &[1])?; + + Ok(one_hot(config, region, &[tensor.into()], num_classes)?) + }; + + region.apply_in_loop(&mut op_tensors, inner_loop_function)?; + + // Allocate memory for the output tensor + let cartesian_coord = output_dims + .iter() + .map(|x| 0..*x) + .multi_cartesian_product() + .collect::>(); + + let mut output = Tensor::>::new(None, &output_dims)?; + + output = output.par_enum_map(|i, _| { + let coord = cartesian_coord[i].clone(); + let mut op_idx = coord.clone(); + let coord_at_dims = vec![coord[dim]]; + op_idx.remove(dim); + + let op_tensor = op_tensors.get(&op_idx); + + let op_tensor = op_tensor.get_inner_tensor()?; + + let one_hot_val = op_tensor.get(&coord_at_dims).clone(); + + Ok::<_, region::RegionError>(one_hot_val) + })?; + + Ok(output.into()) +} + +/// Gather accumulated layout +pub fn gather( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 2], + dim: usize, +) -> Result, Box> { + let (mut input, mut index_clone) = (values[0].clone(), values[1].clone()); + index_clone.flatten(); + if index_clone.is_singleton() { + index_clone.reshape(&[1])?; + } + + let mut assigned_len = vec![]; + if !input.all_prev_assigned() { + input = region.assign(&config.inputs[0], &input)?; + assigned_len.push(input.len()); + } + if !index_clone.all_prev_assigned() { + index_clone = region.assign(&config.inputs[1], &index_clone)?; + assigned_len.push(index_clone.len()); + } + + if 
!assigned_len.is_empty() { + region.increment(*assigned_len.iter().max().ok_or(TensorError::DimError)?); + } + + // Calculate the output tensor size + let input_dims = input.dims(); + let mut output_size = input_dims.to_vec(); + + output_size[dim] = index_clone.dims()[0]; + + // these will be assigned as constants + let mut indices = Tensor::from((0..input.dims()[dim] as u64).map(|x| F::from(x))); + indices.set_visibility(&crate::graph::Visibility::Fixed); + let indices = region.assign(&config.inputs[1], &indices.try_into()?)?; + region.increment(indices.len()); + + // Allocate memory for the output tensor + let cartesian_coord = output_size + .iter() + .map(|x| 0..*x) + .multi_cartesian_product() + .collect::>(); + + let mut output: Tensor> = Tensor::new(None, &output_size)?; + + let inner_loop_function = |i: usize, region: &mut RegionCtx<'_, F>| { + let coord = cartesian_coord[i].clone(); + let index_val = index_clone.get_single_elem(coord[dim])?; + + let mut slice = coord.iter().map(|x| *x..*x + 1).collect::>(); + slice[dim] = 0..input_dims[dim]; + + let mut sliced_input = input.get_slice(&slice)?; + sliced_input.flatten(); + + let res = select( + config, + region, + &[sliced_input, index_val.clone()], + indices.clone(), + )?; + + let res = res.get_inner_tensor()?; + + Ok(res[0].clone()) + }; + + region.apply_in_loop(&mut output, inner_loop_function)?; + + // Reshape the output tensor + if index_clone.is_singleton() { + output_size.remove(dim); + } + output.reshape(&output_size)?; + + Ok(output.into()) +} + +/// Gather accumulated layout +pub fn gather_elements( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 2], + dim: usize, +) -> Result, Box> { + let (mut input, mut index) = (values[0].clone(), values[1].clone()); + + assert_eq!(input.dims().len(), index.dims().len()); + + if !input.all_prev_assigned() { + input = region.assign(&config.inputs[0], &input)?; + } + if !index.all_prev_assigned() { + index = region.assign(&config.inputs[1], 
&index)?; + } + + region.increment(std::cmp::max(input.len(), index.len())); + + // Calculate the output tensor size + let input_dim = input.dims()[dim]; + let output_size = index.dims().to_vec(); + + // these will be assigned as constants + let mut indices = Tensor::from((0..input_dim as u64).map(|x| F::from(x))); + indices.set_visibility(&crate::graph::Visibility::Fixed); + let indices = region.assign(&config.inputs[1], &indices.try_into()?)?; + region.increment(indices.len()); + + // Allocate memory for the output tensor + let cartesian_coord = output_size + .iter() + .map(|x| 0..*x) + .multi_cartesian_product() + .collect::>(); + + let mut output = Tensor::new(None, &output_size)?; + + let inner_loop_function = |i: usize, region: &mut RegionCtx<'_, F>| { + let coord = cartesian_coord[i].clone(); + let index_val = index.get_inner_tensor()?.get(&coord); + + let mut slice = coord.iter().map(|x| *x..*x + 1).collect::>(); + slice[dim] = 0..input_dim; + + let mut sliced_input = input.get_slice(&slice)?; + sliced_input.flatten(); + + let index_valtensor: ValTensor = Tensor::from([index_val.clone()].into_iter()).into(); + + let res = select( + config, + region, + &[sliced_input, index_valtensor], + indices.clone(), + )?; + + let res = res.get_inner_tensor()?; + + Ok(res[0].clone()) + }; + + region.apply_in_loop(&mut output, inner_loop_function)?; + + Ok(output.into()) +} + +/// Gather accumulated layout +pub fn scatter_elements( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 3], + dim: usize, +) -> Result, Box> { + let (mut input, mut index, mut src) = (values[0].clone(), values[1].clone(), values[2].clone()); + + assert_eq!(input.dims().len(), index.dims().len()); + + let mut assigned_len = vec![]; + + if !input.all_prev_assigned() { + input = region.assign(&config.inputs[0], &input)?; + assigned_len.push(input.len()); + } + if !index.all_prev_assigned() { + index = region.assign(&config.inputs[1], &index)?; + assigned_len.push(index.len()); + 
} + if !src.all_prev_assigned() { + src = region.assign(&config.output, &src)?; + assigned_len.push(src.len()); + } + + if !assigned_len.is_empty() { + region.increment(*assigned_len.iter().max().ok_or(TensorError::DimError)?); + } + + // Calculate the output tensor size + let input_dim = input.dims()[dim]; + let output_size = index.dims().to_vec(); + + // these will be assigned as constants + let mut indices = Tensor::from((0..input_dim as u64).map(|x| F::from(x))); + indices.set_visibility(&crate::graph::Visibility::Fixed); + let indices = region.assign(&config.inputs[1], &indices.try_into()?)?; + region.increment(indices.len()); + + // Allocate memory for the output tensor + let cartesian_coord = output_size + .iter() + .map(|x| 0..*x) + .multi_cartesian_product() + .collect::>(); + + let mut unit = Tensor::from(vec![F::from(1)].into_iter()); + unit.set_visibility(&crate::graph::Visibility::Fixed); + let unit: ValTensor = unit.try_into()?; + region.assign(&config.inputs[1], &unit)?; + region.increment(1); + + let mut output: Tensor<()> = Tensor::new(None, &output_size)?; + + let mut inner_loop_function = |i: usize, region: &mut RegionCtx<'_, F>| { + let coord = cartesian_coord[i].clone(); + let index_val = index.get_inner_tensor()?.get(&coord); + + let src_val = src.get_inner_tensor()?.get(&coord); + let src_valtensor: ValTensor = Tensor::from([src_val.clone()].into_iter()).into(); + + let mut slice = coord.iter().map(|x| *x..*x + 1).collect::>(); + slice[dim] = 0..input_dim; + + let mut sliced_input = input.get_slice(&slice)?; + sliced_input.flatten(); + + let index_valtensor: ValTensor = Tensor::from([index_val.clone()].into_iter()).into(); + + let mask = equals(config, region, &[index_valtensor, indices.clone()])?; + + let one_minus_mask = pairwise(config, region, &[unit.clone(), mask.clone()], BaseOp::Sub)?; + + let pairwise_prod = pairwise(config, region, &[src_valtensor, mask], BaseOp::Mult)?; + let pairwise_prod_2 = pairwise( + config, + region, + 
&[sliced_input, one_minus_mask], + BaseOp::Mult, + )?; + + let res = pairwise( + config, + region, + &[pairwise_prod, pairwise_prod_2], + BaseOp::Add, + )?; + + let input_cartesian_coord = slice.into_iter().multi_cartesian_product(); + + let mutable_input_inner = input.get_inner_tensor_mut()?; + + for (i, r) in res.get_inner_tensor()?.iter().enumerate() { + let coord = input_cartesian_coord + .clone() + .nth(i) + .ok_or("invalid coord")?; + *mutable_input_inner.get_mut(&coord) = r.clone(); + } + Ok(()) + }; + + output + .iter_mut() + .enumerate() + .map(|(i, _)| inner_loop_function(i, region)) + .collect::, Box>>()?; + + Ok(input) +} + +/// sum accumulated layout +pub fn sum( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], +) -> Result, Box> { + region.flush()?; + // time this entire function run + let global_start = instant::Instant::now(); + + let mut values = values.clone(); + + // this section has been optimized to death, don't mess with it + let mut removal_indices = values[0].get_const_zero_indices()?; + removal_indices.par_sort_unstable(); + removal_indices.dedup(); + + // is already sorted + values[0].remove_indices(&mut removal_indices, true)?; + + let elapsed = global_start.elapsed(); + trace!("filtering const zero indices took: {:?}", elapsed); + + // if empty return a const + if values[0].is_empty() { + return Ok(Tensor::from([ValType::Constant(F::ZERO)].into_iter()).into()); + } + + let block_width = config.output.num_inner_cols(); + + let assigned_len: usize; + let input = { + let mut input = values[0].clone(); + input.pad_to_zero_rem(block_width)?; + let (res, len) = + region.assign_with_duplication(&config.inputs[1], &input, &config.check_mode, false)?; + assigned_len = len; + res.get_inner()? 
+ }; + + // Now we can assign the dot product + let accumulated_sum = accumulated::sum(&input, block_width)?; + + let (output, output_assigned_len) = region.assign_with_duplication( + &config.output, + &accumulated_sum.into(), + &config.check_mode, + true, + )?; + + // enable the selectors + if !region.is_dummy() { + for i in 0..output_assigned_len { + let (x, _, z) = config + .output + .cartesian_coord(region.linear_coord() + i * block_width); + // skip over duplicates at start of column + if z == 0 && i > 0 { + continue; + } + let selector = if i == 0 { + config.selectors.get(&(BaseOp::SumInit, x, 0)) + } else { + config.selectors.get(&(BaseOp::Sum, x, 0)) + }; + + region.enable(selector, z)?; + } + } + + let last_elem = output.get_slice(&[output.len() - 1..output.len()])?; + + region.increment(assigned_len); + + // last element is the result + Ok(last_elem) +} + +/// product accumulated layout +pub fn prod( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], +) -> Result, Box> { + region.flush()?; + // time this entire function run + let global_start = instant::Instant::now(); + + // this section has been optimized to death, don't mess with it + let removal_indices = values[0].get_const_zero_indices()?; + + let elapsed = global_start.elapsed(); + trace!("finding const zero indices took: {:?}", elapsed); + // if empty return a const + if !removal_indices.is_empty() { + return Ok(Tensor::from([ValType::Constant(F::ZERO)].into_iter()).into()); + } + + let block_width = config.output.num_inner_cols(); + let assigned_len: usize; + let input = { + let mut input = values[0].clone(); + input.pad_to_zero_rem(block_width)?; + let (res, len) = + region.assign_with_duplication(&config.inputs[1], &input, &config.check_mode, false)?; + assigned_len = len; + res.get_inner()? 
+ }; + + // Now we can assign the dot product + let accumulated_prod = accumulated::prod(&input, block_width)?; + + let (output, output_assigned_len) = region.assign_with_duplication( + &config.output, + &accumulated_prod.into(), + &config.check_mode, + true, + )?; + + // enable the selectors + if !region.is_dummy() { + (0..output_assigned_len) + .map(|i| { + let (x, _, z) = config + .output + .cartesian_coord(region.linear_coord() + i * block_width); + // skip over duplicates at start of column + if z == 0 && i > 0 { + return Ok(()); + } + let selector = if i == 0 { + config.selectors.get(&(BaseOp::CumProdInit, x, 0)) + } else { + config.selectors.get(&(BaseOp::CumProd, x, 0)) + }; + + region.enable(selector, z)?; + Ok(()) + }) + .collect::, Box>>()?; + } + + let last_elem = output.get_slice(&[output.len() - 1..output.len()])?; + + region.increment(assigned_len); + + // last element is the result + Ok(last_elem) +} + +/// Axes wise op wrapper +fn axes_wise_op( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], + axes: &[usize], + // generic layout op + op: impl Fn( + &BaseConfig, + &mut RegionCtx, + &[ValTensor; 1], + ) -> Result, Box> + + Send + + Sync, +) -> Result, Box> { + // calculate value of output + + let a = &values[0]; + + if axes.is_empty() { + return Ok(a.clone()); + } + + let mut new_dims = vec![]; + for i in 0..a.dims().len() { + if !axes.contains(&i) { + new_dims.push(a.dims()[i]); + } else { + new_dims.push(1); + } + } + + let mut res = Tensor::new(None, &new_dims)?; + + let cartesian_coord = new_dims + .iter() + .map(|x| 0..*x) + .multi_cartesian_product() + .collect::>(); + + let inner_loop_function = |i: usize, region: &mut RegionCtx<'_, F>| { + let coord = cartesian_coord[i].clone(); + let mut prod_dims = vec![]; + for (i, c) in coord.iter().enumerate() { + if axes.contains(&i) { + prod_dims.push(0..a.dims()[i]); + } else { + prod_dims.push(*c..*c + 1); + } + } + let values = a.get_slice(&prod_dims)?; + let op = 
op(config, region, &[values])?; + + Ok(op.get_inner_tensor()?[0].clone()) + }; + + region.apply_in_loop(&mut res, inner_loop_function)?; + + Ok(res.into()) +} + +/// Sum accumulated layout +pub fn prod_axes( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], + axes: &[usize], +) -> Result, Box> { + // calculate value of output + axes_wise_op(config, region, values, axes, prod) +} + +/// Sum accumulated layout +pub fn sum_axes( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], + axes: &[usize], +) -> Result, Box> { + // calculate value of output + axes_wise_op(config, region, values, axes, sum) +} + +/// argmax layout +pub fn argmax_axes( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], + dim: usize, +) -> Result, Box> { + // these will be assigned as constants + let mut indices = Tensor::from((0..values[0].dims()[dim] as u64).map(|x| F::from(x))); + indices.set_visibility(&crate::graph::Visibility::Fixed); + let indices = region.assign(&config.inputs[1], &indices.try_into()?)?; + region.increment(indices.len()); + + let argmax = move |config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1]| + -> Result, Box> { + argmax(config, region, values, indices.clone()) + }; + + // calculate value of output + axes_wise_op(config, region, values, &[dim], argmax) +} + +/// Max accumulated layout +pub fn max_axes( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], + axes: &[usize], +) -> Result, Box> { + // calculate value of output + + axes_wise_op(config, region, values, axes, max) +} + +/// Argmin layout +pub fn argmin_axes( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], + dim: usize, +) -> Result, Box> { + // calculate value of output + // these will be assigned as constants + let mut indices = Tensor::from((0..values[0].dims()[dim] as u64).map(|x| F::from(x))); + indices.set_visibility(&crate::graph::Visibility::Fixed); + let 
indices = region.assign(&config.inputs[1], &indices.try_into()?)?; + region.increment(indices.len()); + + let argmin = move |config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1]| + -> Result, Box> { + argmin(config, region, values, indices.clone()) + }; + + axes_wise_op(config, region, values, &[dim], argmin) +} + +/// Min accumulated layout +pub fn min_axes( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], + axes: &[usize], +) -> Result, Box> { + // calculate value of output + + axes_wise_op(config, region, values, axes, min) +} + +/// Pairwise (elementwise) op layout +pub fn pairwise( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 2], + op: BaseOp, +) -> Result, Box> { + // time to calculate the value of the output + let global_start = instant::Instant::now(); + + let (mut lhs, mut rhs) = (values[0].clone(), values[1].clone()); + + let broadcasted_shape = get_broadcasted_shape(lhs.dims(), rhs.dims())?; + + lhs.expand(&broadcasted_shape)?; + rhs.expand(&broadcasted_shape)?; + + // original values + let orig_lhs = lhs.clone(); + let orig_rhs = rhs.clone(); + + // get indices of zeros + let first_zero_indices = lhs.get_const_zero_indices()?; + let second_zero_indices = rhs.get_const_zero_indices()?; + let mut removal_indices = match op { + BaseOp::Add | BaseOp::Mult => { + let mut removal_indices = first_zero_indices.clone(); + removal_indices.extend(second_zero_indices.clone()); + removal_indices + } + BaseOp::Sub => second_zero_indices.clone(), + _ => return Err(Box::new(CircuitError::UnsupportedOp)), + }; + removal_indices.dedup(); + + let removal_indices: HashSet<&usize> = HashSet::from_iter(removal_indices.iter()); + let removal_indices_ptr = &removal_indices; + + if lhs.len() != rhs.len() { + return Err(Box::new(CircuitError::DimMismatch(format!( + "pairwise {} layout", + op.as_str() + )))); + } + + let mut inputs = vec![]; + for (i, input) in [lhs.clone(), rhs.clone()].iter().enumerate() 
{ + let inp = { + let res = + region.assign_with_omissions(&config.inputs[i], input, removal_indices_ptr)?; + + res.get_inner()? + }; + + inputs.push(inp); + } + + // Now we can assign the dot product + // time the calc + let start = instant::Instant::now(); + let op_result = match op { + BaseOp::Add => add(&inputs), + BaseOp::Sub => sub(&inputs), + BaseOp::Mult => mult(&inputs), + _ => return Err(Box::new(CircuitError::UnsupportedOp)), + } + .map_err(|e| { + error!("{}", e); + halo2_proofs::plonk::Error::Synthesis + })?; + let elapsed = start.elapsed(); + + let assigned_len = inputs[0].len() - removal_indices.len(); + let mut output = + region.assign_with_omissions(&config.output, &op_result.into(), removal_indices_ptr)?; + trace!("pairwise {} calc took {:?}", op.as_str(), elapsed); + + // Enable the selectors + if !region.is_dummy() { + (0..assigned_len) + .map(|i| { + let (x, y, z) = config.inputs[0].cartesian_coord(region.linear_coord() + i); + let selector = config.selectors.get(&(op.clone(), x, y)); + + region.enable(selector, z)?; + + Ok(()) + }) + .collect::, Box>>()?; + } + region.increment(assigned_len); + + let a_tensor = orig_lhs.get_inner_tensor()?; + let b_tensor = orig_rhs.get_inner_tensor()?; + + let first_zero_indices: HashSet<&usize> = HashSet::from_iter(first_zero_indices.iter()); + let second_zero_indices: HashSet<&usize> = HashSet::from_iter(second_zero_indices.iter()); + + trace!("setting up indices took {:?}", start.elapsed()); + + // infill the zero indices with the correct values from values[0] or values[1] + if !removal_indices_ptr.is_empty() { + output + .get_inner_tensor_mut()? 
+ .par_enum_map_mut_filtered(removal_indices_ptr, |i| { + let val = match op { + BaseOp::Add => { + let a_is_null = first_zero_indices.contains(&i); + let b_is_null = second_zero_indices.contains(&i); + + if a_is_null && b_is_null { + ValType::Constant(F::ZERO) + } else if a_is_null { + b_tensor[i].clone() + } else { + a_tensor[i].clone() + } + } + BaseOp::Sub => { + let a_is_null = first_zero_indices.contains(&i); + // by default b is null in this case for sub + if a_is_null { + ValType::Constant(F::ZERO) + } else { + a_tensor[i].clone() + } + } + BaseOp::Mult => ValType::Constant(F::ZERO), + // can safely panic as the prior check ensures this is not called + _ => unreachable!(), + }; + Ok::<_, TensorError>(val) + })?; + } + + output.reshape(&broadcasted_shape)?; + + let end = global_start.elapsed(); + trace!( + "pairwise {} layout took {:?}, row: {}", + op.as_str(), + end, + region.row() + ); + + Ok(output) +} + +/// +pub fn greater( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 2], +) -> Result, Box> { + let (mut lhs, mut rhs) = (values[0].clone(), values[1].clone()); + + let broadcasted_shape = get_broadcasted_shape(lhs.dims(), rhs.dims())?; + + lhs.expand(&broadcasted_shape)?; + rhs.expand(&broadcasted_shape)?; + + let diff = pairwise(config, region, &[lhs, rhs], BaseOp::Sub)?; + + nonlinearity( + config, + region, + &[diff], + &LookupOp::GreaterThan { a: utils::F32(0.) }, + ) +} + +/// +pub fn greater_equal( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 2], +) -> Result, Box> { + let (mut lhs, mut rhs) = (values[0].clone(), values[1].clone()); + + let broadcasted_shape = get_broadcasted_shape(lhs.dims(), rhs.dims())?; + + lhs.expand(&broadcasted_shape)?; + rhs.expand(&broadcasted_shape)?; + + let diff = pairwise(config, region, &[lhs, rhs], BaseOp::Sub)?; + + nonlinearity( + config, + region, + &[diff], + &LookupOp::GreaterThanEqual { a: utils::F32(0.) 
}, + ) +} + +/// +pub fn less( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 2], +) -> Result, Box> { + // just flip the order and use greater + greater(config, region, &[values[1].clone(), values[0].clone()]) +} + +/// +pub fn less_equal( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 2], +) -> Result, Box> { + // just flip the order and use greater + greater_equal(config, region, &[values[1].clone(), values[0].clone()]) +} + +/// And boolean operation +pub fn and( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 2], +) -> Result, Box> { + let res = pairwise(config, region, values, BaseOp::Mult)?; + + Ok(res) +} + +/// Or boolean operation +pub fn or( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 2], +) -> Result, Box> { + let a = values[0].clone(); + let b = values[1].clone(); + + let iff_values = &[a.clone(), a, b]; + + let res = iff(config, region, iff_values)?; + + Ok(res) +} + +/// Equality boolean operation +pub fn equals( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 2], +) -> Result, Box> { + let diff = pairwise(config, region, values, BaseOp::Sub)?; + + let res = nonlinearity(config, region, &[diff], &LookupOp::KroneckerDelta)?; + + Ok(res) +} + +/// Xor boolean operation +pub fn xor( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 2], +) -> Result, Box> { + let lhs = values[0].clone(); + let rhs = values[1].clone(); + + let lhs_not = not(config, region, &[lhs.clone()])?; + let rhs_not = not(config, region, &[rhs.clone()])?; + + let lhs_and_rhs_not = and(config, region, &[lhs, rhs_not.clone()])?; + let lhs_not_and_rhs = and(config, region, &[rhs, lhs_not])?; + + // we can safely use add and not OR here because we know that lhs_and_rhs_not and lhs_not_and_rhs are =1 at different incices + let res: ValTensor = pairwise( + config, + region, + &[lhs_and_rhs_not, lhs_not_and_rhs], + BaseOp::Add, + )?; + + Ok(res) 
+} + +/// Not boolean operation +pub fn not( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], +) -> Result, Box> { + let mask = values[0].clone(); + + let unit: ValTensor = + Tensor::from(vec![region.assign_constant(&config.inputs[0], F::from(1))?].into_iter()) + .into(); + + // to leverage sparsity we don't assign this guy + let nil: ValTensor = Tensor::from(vec![ValType::Constant(F::from(0))].into_iter()).into(); + region.next(); + + let res = iff(config, region, &[mask, nil, unit])?; + + Ok(res) +} + +/// Iff +pub fn iff( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 3], +) -> Result, Box> { + // if mask > 0 then output a else output b + let (mask, a, b) = (&values[0], &values[1], &values[2]); + + let unit: ValTensor = + Tensor::from(vec![region.assign_constant(&config.inputs[0], F::from(1))?].into_iter()) + .into(); + + // make sure mask is boolean + let assigned_mask = region.assign(&config.inputs[1], mask)?; + + // Enable the selectors + if !region.is_dummy() { + (0..assigned_mask.len()) + .map(|i| { + let (x, y, z) = config.inputs[1].cartesian_coord(region.linear_coord() + i); + let selector = config.selectors.get(&(BaseOp::IsBoolean, x, y)); + region.enable(selector, z)?; + Ok(()) + }) + .collect::, Box>>()?; + } + + region.increment(assigned_mask.len()); + + let one_minus_mask = pairwise(config, region, &[unit, assigned_mask.clone()], BaseOp::Sub)?; + + let masked_a = pairwise(config, region, &[a.clone(), assigned_mask], BaseOp::Mult)?; + + let masked_b = pairwise(config, region, &[b.clone(), one_minus_mask], BaseOp::Mult)?; + + let res = pairwise(config, region, &[masked_a, masked_b], BaseOp::Add)?; + + Ok(res) +} + +/// Negation operation accumulated layout +pub fn neg( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], +) -> Result, Box> { + let input = { + let res = region.assign(&config.inputs[1], &values[0])?; + + res.get_inner()? 
+ }; + + let neg = input.map(|e| -e); + + let output = region.assign(&config.output, &neg.into())?; + + // Enable the selectors + if !region.is_dummy() { + (0..values[0].len()) + .map(|i| { + let (x, y, z) = config.inputs[1].cartesian_coord(region.linear_coord() + i); + let selector = config.selectors.get(&(BaseOp::Neg, x, y)); + + region.enable(selector, z)?; + Ok(()) + }) + .collect::, Box>>()?; + } + + region.increment(output.len()); + + Ok(output) +} + +/// Sumpool accumulated layout +pub fn sumpool( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor], + padding: [(usize, usize); 2], + stride: (usize, usize), + kernel_shape: (usize, usize), +) -> Result, Box> { + let batch_size = values[0].dims()[0]; + let image_channels = values[0].dims()[1]; + + let unit = region.assign_constant(&config.inputs[1], F::from(1))?; + region.next(); + + let mut kernel = Tensor::from(0..kernel_shape.0 * kernel_shape.1).map(|_| unit.clone()); + kernel.reshape(&[1, 1, kernel_shape.0, kernel_shape.1])?; + + let cartesian_coord = [(0..batch_size), (0..image_channels)] + .iter() + .cloned() + .multi_cartesian_product() + .collect::>(); + + let mut res = vec![]; + + cartesian_coord + .iter() + .map(|coord| { + let (b, i) = (coord[0], coord[1]); + let input = values[0].get_slice(&[b..b + 1, i..i + 1])?; + let output = conv( + config, + region, + &[input, kernel.clone().into()], + padding, + stride, + )?; + res.push(output); + Ok(()) + }) + .collect::, Box>>()?; + + let shape = &res[0].dims()[2..]; + let mut last_elem = res[1..] 
+ .iter() + .fold(Ok(res[0].clone()), |acc, elem| acc?.concat(elem.clone()))?; + last_elem.reshape(&[&[batch_size, image_channels], shape].concat())?; + + Ok(last_elem) +} + +/// Convolution accumulated layout +pub fn max_pool2d( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], + padding: [(usize, usize); 2], + stride: (usize, usize), + pool_dims: (usize, usize), +) -> Result, Box> { + let image = values[0].clone(); + + if image.dims().len() != 4 { + return Err(Box::new(TensorError::DimMismatch("max_pool2d".to_string()))); + } + let image_dims = image.dims(); + + let (batch, input_channels, image_height, image_width) = + (image_dims[0], image_dims[1], image_dims[2], image_dims[3]); + + let mut padded_image = image.clone(); + padded_image.pad(padding)?; + + let vert_slides = (image_height + padding[0].0 + padding[1].0 - pool_dims.0) / stride.0 + 1; + let horz_slides = (image_width + padding[0].1 + padding[1].1 - pool_dims.1) / stride.1 + 1; + + let mut output: Tensor> = + Tensor::new(None, &[batch, input_channels, horz_slides, vert_slides])?; + + let cartesian_coord = [ + (0..batch), + (0..input_channels), + (0..vert_slides), + (0..horz_slides), + ] + .iter() + .cloned() + .multi_cartesian_product() + .collect::>(); + + output + .iter_mut() + .enumerate() + .map(|(flat_index, o)| { + let coord = &cartesian_coord[flat_index]; + let (b, i, j, k) = (coord[0], coord[1], coord[2], coord[3]); + let rs = j * stride.0; + let cs = k * stride.1; + let slice = padded_image.get_slice(&[ + b..(b + 1), + i..(i + 1), + rs..(rs + pool_dims.0), + cs..(cs + pool_dims.1), + ])?; + let max_w = max(config, region, &[slice])?; + *o = max_w.get_inner_tensor()?[0].clone(); + Ok(()) + }) + .collect::, Box>>()?; + + let res: ValTensor = output.into(); + + Ok(res) +} + +/// DeConvolution accumulated layout +pub fn deconv( + config: &BaseConfig, + region: &mut RegionCtx, + inputs: &[ValTensor], + padding: [(usize, usize); 2], + output_padding: (usize, usize), + 
stride: (usize, usize), +) -> Result, Box> { + let has_bias = inputs.len() == 3; + let (image, kernel) = (&inputs[0], &inputs[1]); + + if (image.dims().len() != 4) || (kernel.dims().len() != 4) { + return Err(Box::new(TensorError::DimMismatch("deconv".to_string()))); + } + + if stride.0 == 0 || stride.1 == 0 { + return Err(Box::new(TensorError::DimMismatch( + "non-positive stride is not supported for deconv".to_string(), + ))); + } + + if has_bias { + let bias = &inputs[2]; + if (bias.dims().len() != 1) || (bias.dims()[0] != kernel.dims()[0]) { + return Err(Box::new(TensorError::DimMismatch( + "deconv bias".to_string(), + ))); + } + } + + let (kernel_height, kernel_width) = (kernel.dims()[2], kernel.dims()[3]); + + let null_val = ValType::Constant(F::ZERO); + // region.assign_constant(&config.inputs[1], F::from(0))?; + // region.next(); + + let mut expanded_image = image.clone(); + expanded_image.intercalate_values(null_val.clone(), stride.0, 2)?; + expanded_image.intercalate_values(null_val, stride.1, 3)?; + expanded_image.pad([(kernel_height - 1, kernel_width - 1); 2])?; + + // flip order + let channel_coord = (0..kernel.dims()[0]) + .cartesian_product(0..kernel.dims()[1]) + .collect::>(); + + let slice_coord = expanded_image + .dims() + .iter() + .enumerate() + .map(|(i, d)| { + if i == 2 { + padding[0].0..d - padding[1].0 + output_padding.0 + } else if i == 3 { + padding[0].1..d - padding[1].1 + output_padding.1 + } else { + 0..*d + } + }) + .collect::>(); + + let sliced_expanded_image = expanded_image.get_slice(&slice_coord)?; + + let mut inverted_kernels = vec![]; + + for (i, j) in channel_coord { + let channel = kernel.get_slice(&[i..i + 1, j..j + 1])?; + let mut channel = Tensor::from(channel.get_inner_tensor()?.clone().into_iter().rev()); + channel.reshape(&[kernel.dims()[2], kernel.dims()[3]])?; + inverted_kernels.push(channel); + } + + let mut deconv_kernel = + Tensor::new(Some(&inverted_kernels), &[inverted_kernels.len()])?.combine()?; + 
deconv_kernel.reshape(kernel.dims())?; + + // tensorflow formatting patch + if kernel.dims()[0] == sliced_expanded_image.dims()[1] { + deconv_kernel.reshape(&[ + kernel.dims()[1], + kernel.dims()[0], + kernel.dims()[2], + kernel.dims()[3], + ])?; + } + + let conv_input = if has_bias { + vec![ + sliced_expanded_image, + deconv_kernel.clone().into(), + inputs[2].clone(), + ] + } else { + vec![sliced_expanded_image, deconv_kernel.clone().into()] + }; + + let output = conv(config, region, &conv_input, [(0, 0); 2], (1, 1))?; + + Ok(output) +} + +/// Convolution accumulated layout +pub fn conv( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor], + padding: [(usize, usize); 2], + stride: (usize, usize), +) -> Result, Box> { + let has_bias = values.len() == 3; + let (mut image, mut kernel) = (values[0].clone(), values[1].clone()); + + // we specifically want to use the same kernel and image for all the convolutions and need to enforce this by assigning them + // 1. assign the kernel + let mut assigned_len = vec![]; + + if !kernel.all_prev_assigned() { + kernel = region.assign(&config.inputs[0], &kernel)?; + assigned_len.push(kernel.len()); + } + // 2. 
assign the image + if !image.all_prev_assigned() { + image = region.assign(&config.inputs[1], &image)?; + assigned_len.push(image.len()); + } + + if !assigned_len.is_empty() { + region.increment(*assigned_len.iter().max().ok_or(TensorError::DimError)?); + } + + let og_image_dims = image.dims().to_vec(); + let og_kernel_dims = kernel.dims().to_vec(); + // ensure inputs are 4D tensors + if og_image_dims.len() == 3 { + // adds a dummy image_channels dimension + let mut new_dims = image.dims().to_vec(); + // insert 1 at the input_channels pos + if og_kernel_dims.len() == 3 { + new_dims.insert(1, 1); + } else { + new_dims.insert(0, 1); + } + image.reshape(&new_dims)?; + } + + // ensure kernel is 4D tensor + if og_kernel_dims.len() == 3 && og_image_dims.len() == 3 { + // adds a dummy image_channels dimension + let mut new_dims = kernel.dims().to_vec(); + // insert 1 at the input_channels pos + new_dims.insert(1, 1); + kernel.reshape(&new_dims)?; + } + + // if not 4D then error + if (image.dims().len() != 4) + || (kernel.dims().len() != 4) + || ((image.dims()[1] != kernel.dims()[1]) && (kernel.dims()[1] != 1)) + { + return Err(Box::new(TensorError::DimMismatch("conv".to_string()))); + } + + let image_dims = image.dims(); + let kernel_dims = kernel.dims(); + + let mut padded_image = image.clone(); + padded_image.pad(padding)?; + + let (batch_size, output_channels, input_channels, kernel_height, kernel_width) = ( + image_dims[0], + kernel_dims[0], + image_dims[1], + kernel_dims[2], + kernel_dims[3], + ); + + let (image_height, image_width) = (image_dims[2], image_dims[3]); + + let vert_slides = (image_height + padding[0].0 + padding[1].0 - kernel_height) / stride.0 + 1; + let horz_slides = (image_width + padding[0].1 + padding[1].1 - kernel_width) / stride.1 + 1; + + let num_groups = input_channels / kernel_dims[1]; + let input_channels_per_group = input_channels / num_groups; + let output_channels_per_group = output_channels / num_groups; + + if output_channels_per_group 
== 0 { + return Err(Box::new(TensorError::DimMismatch(format!( + "Given groups={}, expected kernel to be at least {} at dimension 0 but got {} instead", + num_groups, num_groups, output_channels_per_group + )))); + } + + let num_outputs = + batch_size * num_groups * output_channels_per_group * vert_slides * horz_slides; + + let mut output: Tensor> = Tensor::new(None, &[num_outputs])?; + + let cartesian_coord = [ + (0..batch_size), + (0..num_groups), + (0..output_channels_per_group), + (0..vert_slides), + (0..horz_slides), + ] + .iter() + .cloned() + .multi_cartesian_product() + .collect::>(); + + let inner_loop_function = |idx: usize, region: &mut RegionCtx| { + let cartesian_coord_per_group = &cartesian_coord[idx]; + let (batch, group, i, j, k) = ( + cartesian_coord_per_group[0], + cartesian_coord_per_group[1], + cartesian_coord_per_group[2], + cartesian_coord_per_group[3], + cartesian_coord_per_group[4], + ); + let rs = j * stride.0; + let cs = k * stride.1; + + let start_channel = group * input_channels_per_group; + let end_channel = start_channel + input_channels_per_group; + + let mut local_image = padded_image.get_slice(&[ + batch..batch + 1, + start_channel..end_channel, + rs..(rs + kernel_height), + cs..(cs + kernel_width), + ])?; + + local_image.flatten(); + + let start_kernel_index = group * output_channels_per_group + i; + let end_kernel_index = start_kernel_index + 1; + let mut local_kernel = kernel.get_slice(&[start_kernel_index..end_kernel_index])?; + + local_kernel.flatten(); + + // this is dot product notation in einsum format + let mut res = einsum(config, region, &[local_image, local_kernel], "i,i->")?; + + if has_bias { + let bias = values[2].get_single_elem(start_kernel_index)?; + res = pairwise(config, region, &[res, bias], BaseOp::Add)?; + } + region.flush()?; + + Ok(res.get_inner_tensor()?[0].clone()) + }; + + region.flush()?; + region.apply_in_loop(&mut output, inner_loop_function)?; + + let reshape_output = |output: &mut Tensor>| -> 
Result<(), TensorError> { + // remove dummy batch dimension if we added one + if og_image_dims.len() == 3 && vert_slides == 1 { + output.reshape(&[batch_size, output_channels, horz_slides])?; + } else if og_image_dims.len() == 3 { + output.reshape(&[output_channels, vert_slides, horz_slides])?; + } else { + output.reshape(&[batch_size, output_channels, vert_slides, horz_slides])?; + } + Ok(()) + }; + + // remove dummy batch dimension if we added one + reshape_output(&mut output)?; + + let output: ValTensor<_> = output.into(); + + Ok(output) +} + +/// Power accumulated layout +pub fn pow( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], + exponent: u32, +) -> Result, Box> { + let mut t = values[0].clone(); + + for _ in 1..exponent { + t = pairwise(config, region, &[t, values[0].clone()], BaseOp::Mult)?; + } + + Ok(t) +} + +/// Rescaled op accumulated layout +pub fn rescale( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor], + scales: &[(usize, u128)], +) -> Result>, Box> { + let mut rescaled_inputs = vec![]; + for (i, ri) in values.iter().enumerate() { + if scales[i].1 == 1 { + rescaled_inputs.push(ri.clone()); + continue; + } + + let multiplier: ValTensor = + Tensor::from(vec![ValType::Constant(F::from(scales[i].1 as u64))].into_iter()).into(); + let scaled_input = pairwise(config, region, &[ri.clone(), multiplier], BaseOp::Mult)?; + rescaled_inputs.push(scaled_input); + } + + Ok(rescaled_inputs) +} + +/// Pack accumulated layout +pub fn pack( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], + base: u32, + scale: u32, +) -> Result, Box> { + let mut t = values[0].clone(); + t.flatten(); + + // these unwraps should never ever fail if the Tensortypes are correctly implemented + // if anything we want these to hard fail if not implemented + let mut base_t = ::zero().ok_or(TensorError::FeltError)?; + for _ in 0..base { + base_t += ::one().ok_or(TensorError::FeltError)?; + } + let mut accum_base 
= vec![]; + let base_tensor = Tensor::new(Some(&[base_t]), &[1])?; + for i in 0..t.dims().iter().product::() { + accum_base.push(Value::known(base_tensor.pow((i as u32) * (scale + 1))?[0])); + } + + let base_tensor = Tensor::new(Some(&accum_base), &[accum_base.len()])?; + let base_prod = pairwise( + config, + region, + &[t.clone(), base_tensor.into()], + BaseOp::Mult, + )?; + + let res = sum(config, region, &[base_prod])?; + + Ok(res) +} + +/// Dummy (no contraints) reshape layout +pub fn reshape( + values: &[ValTensor; 1], + new_dims: &[usize], +) -> Result, Box> { + let mut t = values[0].clone(); + t.reshape(new_dims)?; + Ok(t) +} + +/// Dummy (no contraints) move_axis layout +pub fn move_axis( + values: &[ValTensor; 1], + source: usize, + destination: usize, +) -> Result, Box> { + let mut t = values[0].clone(); + t.move_axis(source, destination)?; + Ok(t) +} + +/// resize layout +pub fn resize( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], + scales: &[usize], +) -> Result, Box> { + let mut output = region.assign(&config.output, &values[0])?; + region.increment(output.len()); + output.resize(scales)?; + + Ok(output) +} + +/// Slice layout +pub fn slice( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], + axis: &usize, + start: &usize, + end: &usize, +) -> Result, Box> { + // assigns the instance to the advice. + let mut output = region.assign(&config.output, &values[0])?; + region.increment(output.len()); + output.slice(axis, start, end)?; + + Ok(output) +} + +/// Concat layout +pub fn concat( + values: &[ValTensor], + axis: &usize, +) -> Result, Box> { + let collected_inner: Result>, _> = + values.iter().map(|e| e.get_inner_tensor()).collect(); + let collected_inner = collected_inner?; + + Ok(tensor::ops::concat(&collected_inner, *axis)?.into()) +} + +/// Identity constraint. Usually used to constrain an instance column to an advice so the returned cells / values can be operated upon. 
+pub fn identity( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], +) -> Result, Box> { + let mut output = values[0].clone(); + if !output.all_prev_assigned() { + output = region.assign(&config.output, &values[0])?; + region.increment(output.len()); + } + + Ok(output) +} + +/// Boolean identity constraint. Usually used to constrain an instance column to an advice so the returned cells / values can be operated upon. +pub fn boolean_identity( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], +) -> Result, Box> { + let output = region.assign(&config.inputs[1], &values[0])?; + // Enable the selectors + if !region.is_dummy() { + (0..output.len()) + .map(|j| { + let (x, y, z) = config.inputs[1].cartesian_coord(region.linear_coord() + j); + let selector = config.selectors.get(&(BaseOp::IsBoolean, x, y)); + + region.enable(selector, z)?; + Ok(()) + }) + .collect::, Box>>()?; + } + region.increment(output.len()); + + Ok(output) +} + +/// Downsample layout +pub fn downsample( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], + axis: &usize, + stride: &usize, + modulo: &usize, +) -> Result, Box> { + let input = region.assign(&config.inputs[0], &values[0])?; + let processed_output = + tensor::ops::downsample(input.get_inner_tensor()?, *axis, *stride, *modulo)?; + let output = region.assign(&config.output, &processed_output.into())?; + region.increment(std::cmp::max(input.len(), output.len())); + Ok(output) +} + +/// layout for enforcing two sets of cells to be equal +pub fn enforce_equality( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 2], +) -> Result, Box> { + // assert of same len + + // assigns the instance to the advice. 
+ let input = region.assign(&config.inputs[1], &values[0])?; + let output = region.assign(&config.output, &values[1])?; + + if !region.is_dummy() { + region.constrain_equal(&input, &output)?; + } + + region.increment(output.len()); + + Ok(output) +} + +/// layout for nonlinearity check. +pub fn nonlinearity( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], + nl: &LookupOp, +) -> Result, Box> { + // time the entire operation + let timer = instant::Instant::now(); + + let x = values[0].clone(); + + let removal_indices = values[0].get_const_indices()?; + let removal_indices: HashSet<&usize> = HashSet::from_iter(removal_indices.iter()); + let removal_indices_ptr = &removal_indices; + + let w = region.assign_with_omissions(&config.lookup_input, &x, removal_indices_ptr)?; + + let output = w.get_inner_tensor()?.par_enum_map(|i, e| { + Ok::<_, TensorError>(if let Some(f) = e.get_felt_eval() { + if !removal_indices.contains(&i) { + Value::known(Op::::f(nl, &[Tensor::from(vec![f].into_iter())])?.output[0]).into() + } else { + ValType::Constant(Op::::f(nl, &[Tensor::from(vec![f].into_iter())])?.output[0]) + } + } else { + Value::::unknown().into() + }) + })?; + + let assigned_len = x.len() - removal_indices.len(); + let mut output = + region.assign_with_omissions(&config.lookup_output, &output.into(), removal_indices_ptr)?; + + let is_dummy = region.is_dummy(); + + let table_index: ValTensor = w + .get_inner_tensor()? + .par_enum_map(|i, e| { + Ok::<_, TensorError>(if let Some(f) = e.get_felt_eval() { + let col_idx = if !is_dummy { + let table = config.tables.get(nl).ok_or(TensorError::TableLookupError)?; + table.get_col_index(f) + } else { + F::ZERO + }; + if !removal_indices.contains(&i) { + Value::known(col_idx).into() + } else { + ValType::Constant(col_idx) + } + } else { + Value::::unknown().into() + }) + })? 
+ .into(); + + region.assign_with_omissions(&config.lookup_index, &table_index, removal_indices_ptr)?; + + if !is_dummy { + (0..assigned_len) + .map(|i| { + let (x, y, z) = config + .lookup_input + .cartesian_coord(region.linear_coord() + i); + let selector = config.lookup_selectors.get(&(nl.clone(), x, y)); + region.enable(selector, z)?; + Ok(()) + }) + .collect::, Box>>()?; + } + + region.increment(assigned_len); + + output.reshape(x.dims())?; + + let elapsed = timer.elapsed(); + trace!( + "nonlinearity {} layout took {:?}, row: {:?}", + >::as_string(nl), + elapsed, + region.row() + ); + + // constrain the calculated output to a column + Ok(output) +} + +/// mean function layout +pub fn mean( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], + scale: usize, +) -> Result, Box> { + let x = &values[0]; + + let sum_x = sum(config, region, &[x.clone()])?; + let nl = LookupOp::Div { + denom: utils::F32((scale * x.len()) as f32), + }; + nonlinearity(config, region, &[sum_x], &nl) +} + +/// Argmax +pub fn argmax( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], + indices: ValTensor, +) -> Result, Box> { + // this is safe because we later constrain it + let argmax = values[0] + .get_int_evals()? 
+ .into_par_iter() + .enumerate() + // we value the first index in the case of a tie + .max_by_key(|(idx, value)| (*value, -(*idx as i64))) + .map(|(idx, _)| idx as i128); + let argmax_val: ValTensor = match argmax { + None => Tensor::new(Some(&[Value::::unknown()]), &[1])?.into(), + Some(i) => Tensor::new(Some(&[Value::known(i128_to_felt::(i))]), &[1])?.into(), + }; + + let assigned_argmax: ValTensor = region.assign(&config.inputs[1], &argmax_val)?; + region.increment(assigned_argmax.len()); + + let claimed_val = select( + config, + region, + &[values[0].clone(), assigned_argmax.clone()], + indices, + )?; + + let max_val = max(config, region, &[values[0].clone()])?; + + enforce_equality(config, region, &[claimed_val, max_val])?; + + Ok(assigned_argmax) +} + +/// Argmin +pub fn argmin( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], + indices: ValTensor, +) -> Result, Box> { + // this is safe because we later constrain it + let argmin = values[0] + .get_int_evals()? 
+ .into_par_iter() + .enumerate() + // we value the first index in the case of a tie + .min_by_key(|(idx, value)| (*value, (*idx as i64))) + .map(|(idx, _)| idx as i128); + let argmin_val: ValTensor = match argmin { + None => Tensor::new(Some(&[Value::::unknown()]), &[1])?.into(), + Some(i) => Tensor::new(Some(&[Value::known(i128_to_felt::(i))]), &[1])?.into(), + }; + + let assigned_argmin: ValTensor = region.assign(&config.inputs[1], &argmin_val)?; + region.increment(assigned_argmin.len()); + + // these will be assigned as constants + let claimed_val = select( + config, + region, + &[values[0].clone(), assigned_argmin.clone()], + indices, + )?; + let min_val = min(config, region, &[values[0].clone()])?; + + enforce_equality(config, region, &[claimed_val, min_val])?; + + Ok(assigned_argmin) +} + +/// max layout +pub fn max( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], +) -> Result, Box> { + // this is safe because we later constrain it + let max_int = values[0].get_int_evals()?.into_par_iter().max(); + let max_val: ValTensor = match max_int { + None => Tensor::new(Some(&[Value::::unknown()]), &[1])?.into(), + Some(i) => Tensor::new(Some(&[Value::known(i128_to_felt::(i))]), &[1])?.into(), + }; + + let assigned_max_val: ValTensor = region.assign(&config.inputs[1], &max_val)?; + region.increment(assigned_max_val.len()); + + let unit: ValTensor = + Tensor::from(vec![region.assign_constant(&config.inputs[1], F::from(1))?].into_iter()) + .into(); + region.next(); + + // max(x - 1) + let max_minus_1 = pairwise( + config, + region, + &[assigned_max_val.clone(), unit.clone()], + BaseOp::Sub, + )?; + + // x - max(x - 1) + let diff = pairwise( + config, + region, + &[values[0].clone(), max_minus_1], + BaseOp::Sub, + )?; + // relu(x - max(x - 1)) + let relu = nonlinearity(config, region, &[diff], &LookupOp::ReLU)?; + + let len = relu.dims().iter().product(); + + // y_i*(1 - y_i) =0 // assert the values are either 0 or 1 + 
region.assign(&config.inputs[1], &relu)?; + + if !region.is_dummy() { + (0..len) + .map(|i| { + let (x, y, z) = config.inputs[1].cartesian_coord(region.linear_coord() + i); + let selector = config.selectors.get(&(BaseOp::IsBoolean, x, y)); + region.enable(selector, z)?; + Ok(()) + }) + .collect::, Box>>()?; + } + + region.increment(len); + + // sum(relu(x - max(x - 1))) + let sum_relu = sum(config, region, &[relu])?; + // 1 - sum(relu(x - max(x - 1))) + let one_minus_sum_relu = pairwise(config, region, &[unit, sum_relu], BaseOp::Sub)?; + // relu(1 - sum(relu(x - max(x - 1)))) + let relu_one_minus_sum_relu = + nonlinearity(config, region, &[one_minus_sum_relu], &LookupOp::ReLU)?; + + // constraining 1 - sum(relu(x - max(x - 1))) = 0 + region.assign(&config.inputs[1], &relu_one_minus_sum_relu)?; + + let (x, y, z) = config.output.cartesian_coord(region.linear_coord()); + let selector = config.selectors.get(&(BaseOp::IsZero, x, y)); + region.enable(selector, z)?; + + region.increment(relu_one_minus_sum_relu.len()); + + Ok(assigned_max_val) +} + +/// min layout +pub fn min( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], +) -> Result, Box> { + // this is safe because we later constrain it + + let min_int = values[0].get_int_evals()?.into_par_iter().min(); + let min_val: ValTensor = match min_int { + None => Tensor::new(Some(&[Value::::unknown()]), &[1])?.into(), + Some(i) => Tensor::new(Some(&[Value::known(i128_to_felt::(i))]), &[1])?.into(), + }; + + let assigned_min_val = region.assign(&config.inputs[1], &min_val)?; + region.increment(assigned_min_val.len()); + + let unit: ValTensor = + Tensor::from(vec![region.assign_constant(&config.inputs[1], F::from(1))?].into_iter()) + .into(); + region.next(); + + // min(x + 1) + let min_plus_1 = pairwise( + config, + region, + &[assigned_min_val.clone(), unit.clone()], + BaseOp::Add, + )?; + + // min(x + 1) - x + let diff = pairwise( + config, + region, + &[min_plus_1, values[0].clone()], + 
BaseOp::Sub, + )?; + + // relu(min(x + 1) - x) + let relu = nonlinearity(config, region, &[diff], &LookupOp::ReLU)?; + + let len = relu.dims().iter().product(); + + region.assign(&config.inputs[1], &relu)?; + // y_i*(1 - y_i) =0 // assert the values are either 0 or 1 + if !region.is_dummy() { + (0..len) + .map(|i| { + let (x, y, z) = config.inputs[1].cartesian_coord(region.linear_coord() + i); + let selector = config.selectors.get(&(BaseOp::IsBoolean, x, y)); + region.enable(selector, z)?; + Ok(()) + }) + .collect::, Box>>()?; + } + + region.increment(len); + + // sum(relu(min(x + 1) - x)) + let sum_relu = sum(config, region, &[relu])?; + // 1 - sum(relu(min(x + 1) - x)) + let one_minus_sum_relu = pairwise(config, region, &[unit, sum_relu], BaseOp::Sub)?; + // relu(1 - sum(relu(min(x + 1) - x))) + + let relu_one_minus_sum_relu = + nonlinearity(config, region, &[one_minus_sum_relu], &LookupOp::ReLU)?; + + region.assign(&config.inputs[1], &relu_one_minus_sum_relu)?; + + // constraining product to 0 + let (x, y, z) = config.output.cartesian_coord(region.linear_coord()); + let selector = config.selectors.get(&(BaseOp::IsZero, x, y)); + region.enable(selector, z)?; + + region.increment(relu_one_minus_sum_relu.len()); + + Ok(assigned_min_val) +} + +fn multi_dim_axes_op( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], + axes: &[usize], + op: impl Fn( + &BaseConfig, + &mut RegionCtx, + &[ValTensor; 1], + ) -> Result, Box> + + Send + + Sync, +) -> Result, Box> { + let mut input = values[0].clone(); + + if !input.all_prev_assigned() { + input = region.assign(&config.inputs[0], &input)?; + region.increment(input.len()); + } + + if input.dims().len() == 1 { + return op(config, region, &[input]); + } + + // Calculate the output tensor size + let input_dims = input.dims(); + + let mut sorted_axes = axes.to_vec(); + // descending order + sorted_axes.sort_by(|x, y| y.cmp(x)); + + let mut output_size_without_dim = input_dims.to_vec(); + for dim in 
&sorted_axes { + output_size_without_dim.remove(*dim); + } + + let mut op_tensors = Tensor::>::new(None, &output_size_without_dim)?; + + // Allocate memory for the output tensor + let cartesian_coord = output_size_without_dim + .iter() + .map(|x| 0..*x) + .multi_cartesian_product() + .collect::>(); + + let inner_loop_function = |i: usize, region: &mut RegionCtx| { + let coord = cartesian_coord[i].clone(); + let mut slice = coord.iter().map(|x| *x..*x + 1).collect::>(); + + for dim in &sorted_axes { + slice.insert(*dim, 0..input_dims[*dim]); + } + + let mut sliced_input = input.get_slice(&slice)?; + sliced_input.flatten(); + + Ok(op(config, region, &[sliced_input])?) + }; + + region.apply_in_loop(&mut op_tensors, inner_loop_function)?; + + // assert all op_tensors have the same dims + let sample_op_output_size = op_tensors[0].dims(); + + // now deduce the output size from the dims of the output tensors + let mut output_size = input_dims.to_vec(); + for dim in axes.iter().enumerate() { + output_size[*dim.1] = sample_op_output_size[dim.0]; + } + + // Allocate memory for the output tensor + let cartesian_coord = output_size + .iter() + .map(|x| 0..*x) + .multi_cartesian_product() + .collect::>(); + + let mut output = Tensor::>::new(None, &output_size)?; + + output = output.par_enum_map(|i, _| { + let coord = cartesian_coord[i].clone(); + let mut op_idx = coord.clone(); + let mut coord_at_dims = vec![]; + for dim in &sorted_axes { + op_idx.remove(*dim); + } + for dim in axes { + coord_at_dims.push(coord[*dim]); + } + + let topk_elem = op_tensors + .get(&op_idx) + .get_inner_tensor()? 
+ .get(&coord_at_dims) + .clone(); + + Ok::<_, region::RegionError>(topk_elem) + })?; + + Ok(output.into()) +} + +/// softmax layout +pub fn softmax_axes( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], + scale: utils::F32, + axes: &[usize], +) -> Result, Box> { + let soft_max_at_scale = move |config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1]| + -> Result, Box> { + softmax(config, region, values, scale) + }; + + let output = multi_dim_axes_op(config, region, values, axes, soft_max_at_scale)?; + + Ok(output) +} + +/// softmax func +pub fn softmax( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], + scale: utils::F32, +) -> Result, Box> { + // elementwise exponential + let ex = nonlinearity(config, region, values, &LookupOp::Exp { scale })?; + + // sum of exps + let denom = sum(config, region, &[ex.clone()])?; + // get the inverse + + let inv_denom = nonlinearity( + config, + region, + &[denom], + // we set to input scale + output_scale so the output scale is output)scale + &LookupOp::Recip { + scale: scale.0.powf(2.0).into(), + }, + )?; + + // product of num * (1 / denom) = 2*output_scale + let softmax = pairwise(config, region, &[ex, inv_denom], BaseOp::Mult)?; + + Ok(softmax) +} + +/// Checks that the percent error between the expected public output and the actual output value +/// is within the percent error expressed by the `tol` input, where `tol == 1.0` means the percent +/// error tolerance is 1 percent. 
+pub fn range_check_percent( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 2], + scale: utils::F32, + tol: f32, +) -> Result, Box> { + if tol == 0.0 { + // regular equality constraint + return enforce_equality(config, region, values); + } + + // Calculate the difference between the expected output and actual output + let diff = pairwise(config, region, values, BaseOp::Sub)?; + + let scale_squared = scale.0.powf(2.0); + // Calculate the reciprocal of the expected output tensor, scaling by double the scaling factor + let recip = nonlinearity( + config, + region, + &[values[0].clone()], + &LookupOp::Recip { + scale: scale_squared.into(), + }, + )?; + // Multiply the difference by the recip + let product = pairwise(config, region, &[diff, recip], BaseOp::Mult)?; + + // Use the greater than look up table to check if the percent error is within the tolerance for upper bound + let tol = tol / 100.0; + let upper_bound = nonlinearity( + config, + region, + &[product.clone()], + &LookupOp::GreaterThan { + a: utils::F32(tol * scale_squared), + }, + )?; + + // Negate the product + let neg_product = neg(config, region, &[product])?; + + // Use the greater than look up table to check if the percent error is within the tolerance for lower bound + let lower_bound = nonlinearity( + config, + region, + &[neg_product], + &LookupOp::GreaterThan { + a: utils::F32(tol * scale_squared), + }, + )?; + + // Add the lower_bound and upper_bound + let sum = pairwise(config, region, &[lower_bound, upper_bound], BaseOp::Add)?; + + // Assign the sum tensor to the inputs + region.assign(&config.inputs[1], &sum)?; + + // Constrain the sum to be all zeros + let (x, y, z) = config.output.cartesian_coord(region.linear_coord()); + let selector = config.selectors.get(&(BaseOp::IsZero, x, y)); + region.enable(selector, z)?; + + region.increment(sum.len()); + + Ok(sum) +} diff --git a/mnist_ezkl/src/circuit/ops/lookup.rs b/mnist_ezkl/src/circuit/ops/lookup.rs new file mode 100644 
index 0000000..a927e8e --- /dev/null +++ b/mnist_ezkl/src/circuit/ops/lookup.rs @@ -0,0 +1,317 @@ +use super::*; +use serde::{Deserialize, Serialize}; +use std::error::Error; + +use crate::{ + circuit::{layouts, utils}, + fieldutils::{felt_to_i128, i128_to_felt}, + graph::{multiplier_to_scale, scale_to_multiplier}, + tensor::{self, Tensor, TensorError, TensorType}, +}; + +use super::Op; +use halo2curves::ff::PrimeField; + +#[allow(missing_docs)] +/// An enum representing the operations that can be used to express more complex operations via accumulation +#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, Serialize)] +pub enum LookupOp { + Abs, + Div { + denom: utils::F32, + }, + ReLU, + Max { + scales: (usize, usize), + a: utils::F32, + }, + Min { + scales: (usize, usize), + a: utils::F32, + }, + Ceil { + scale: utils::F32, + }, + Floor { + scale: utils::F32, + }, + Round { + scale: utils::F32, + }, + RoundHalfToEven { + scale: utils::F32, + }, + Sqrt { + scale: utils::F32, + }, + Rsqrt { + scale: utils::F32, + }, + Recip { + scale: utils::F32, + }, + LeakyReLU { + slope: utils::F32, + }, + Sigmoid { + scale: utils::F32, + }, + Ln { + scale: utils::F32, + }, + Exp { + scale: utils::F32, + }, + Cos { + scale: utils::F32, + }, + ACos { + scale: utils::F32, + }, + Cosh { + scale: utils::F32, + }, + ACosh { + scale: utils::F32, + }, + Sin { + scale: utils::F32, + }, + ASin { + scale: utils::F32, + }, + Sinh { + scale: utils::F32, + }, + ASinh { + scale: utils::F32, + }, + Tan { + scale: utils::F32, + }, + ATan { + scale: utils::F32, + }, + Tanh { + scale: utils::F32, + }, + ATanh { + scale: utils::F32, + }, + Erf { + scale: utils::F32, + }, + GreaterThan { + a: utils::F32, + }, + LessThan { + a: utils::F32, + }, + GreaterThanEqual { + a: utils::F32, + }, + LessThanEqual { + a: utils::F32, + }, + Sign, + KroneckerDelta, + Pow { + scale: utils::F32, + a: utils::F32, + }, +} + +impl LookupOp { + /// Returns the range of values that can be 
represented by the table + pub fn bit_range(max_len: usize) -> (i128, i128) { + let range = (max_len - 1) as f64 / 2_f64; + let range = range as i128; + (-range, range) + } +} + +impl Op for LookupOp { + /// Returns a reference to the Any trait. + fn as_any(&self) -> &dyn Any { + self + } + /// Matches a [Op] to an operation in the `tensor::ops` module. + fn f(&self, x: &[Tensor]) -> Result, TensorError> { + let x = x[0].clone().map(|x| felt_to_i128(x)); + let res = match &self { + LookupOp::Abs => Ok(tensor::ops::abs(&x)?), + LookupOp::Ceil { scale } => Ok(tensor::ops::nonlinearities::ceil(&x, scale.into())), + LookupOp::Floor { scale } => Ok(tensor::ops::nonlinearities::floor(&x, scale.into())), + LookupOp::Round { scale } => Ok(tensor::ops::nonlinearities::round(&x, scale.into())), + LookupOp::RoundHalfToEven { scale } => Ok( + tensor::ops::nonlinearities::round_half_to_even(&x, scale.into()), + ), + LookupOp::Pow { scale, a } => Ok(tensor::ops::nonlinearities::pow( + &x, + scale.0.into(), + a.0.into(), + )), + LookupOp::KroneckerDelta => Ok(tensor::ops::nonlinearities::kronecker_delta(&x)), + LookupOp::Max { scales, a } => Ok(tensor::ops::nonlinearities::max( + &x, + scales.0, + scales.1, + a.0.into(), + )), + LookupOp::Min { scales, a } => Ok(tensor::ops::nonlinearities::min( + &x, + scales.0, + scales.1, + a.0.into(), + )), + LookupOp::Sign => Ok(tensor::ops::nonlinearities::sign(&x)), + LookupOp::LessThan { a } => Ok(tensor::ops::nonlinearities::less_than( + &x, + f32::from(*a).into(), + )), + LookupOp::LessThanEqual { a } => Ok(tensor::ops::nonlinearities::less_than_equal( + &x, + f32::from(*a).into(), + )), + LookupOp::GreaterThan { a } => Ok(tensor::ops::nonlinearities::greater_than( + &x, + f32::from(*a).into(), + )), + LookupOp::GreaterThanEqual { a } => Ok( + tensor::ops::nonlinearities::greater_than_equal(&x, f32::from(*a).into()), + ), + LookupOp::Div { denom } => Ok(tensor::ops::nonlinearities::const_div( + &x, + f32::from(*denom).into(), + )), + 
LookupOp::Recip { scale } => Ok(tensor::ops::nonlinearities::recip(&x, scale.into())), + LookupOp::ReLU => Ok(tensor::ops::nonlinearities::leakyrelu(&x, 0_f64)), + + LookupOp::LeakyReLU { slope: a } => { + Ok(tensor::ops::nonlinearities::leakyrelu(&x, a.0.into())) + } + LookupOp::Sigmoid { scale } => { + Ok(tensor::ops::nonlinearities::sigmoid(&x, scale.into())) + } + LookupOp::Sqrt { scale } => Ok(tensor::ops::nonlinearities::sqrt(&x, scale.into())), + LookupOp::Rsqrt { scale } => Ok(tensor::ops::nonlinearities::rsqrt(&x, scale.into())), + LookupOp::Erf { scale } => Ok(tensor::ops::nonlinearities::erffunc(&x, scale.into())), + LookupOp::Exp { scale } => Ok(tensor::ops::nonlinearities::exp(&x, scale.into())), + LookupOp::Ln { scale } => Ok(tensor::ops::nonlinearities::ln(&x, scale.into())), + LookupOp::Cos { scale } => Ok(tensor::ops::nonlinearities::cos(&x, scale.into())), + LookupOp::ACos { scale } => Ok(tensor::ops::nonlinearities::acos(&x, scale.into())), + LookupOp::Cosh { scale } => Ok(tensor::ops::nonlinearities::cosh(&x, scale.into())), + LookupOp::ACosh { scale } => Ok(tensor::ops::nonlinearities::acosh(&x, scale.into())), + LookupOp::Sin { scale } => Ok(tensor::ops::nonlinearities::sin(&x, scale.into())), + LookupOp::ASin { scale } => Ok(tensor::ops::nonlinearities::asin(&x, scale.into())), + LookupOp::Sinh { scale } => Ok(tensor::ops::nonlinearities::sinh(&x, scale.into())), + LookupOp::ASinh { scale } => Ok(tensor::ops::nonlinearities::asinh(&x, scale.into())), + LookupOp::Tan { scale } => Ok(tensor::ops::nonlinearities::tan(&x, scale.into())), + LookupOp::ATan { scale } => Ok(tensor::ops::nonlinearities::atan(&x, scale.into())), + LookupOp::ATanh { scale } => Ok(tensor::ops::nonlinearities::atanh(&x, scale.into())), + LookupOp::Tanh { scale } => Ok(tensor::ops::nonlinearities::tanh(&x, scale.into())), + }?; + + let output = res.map(|x| i128_to_felt(x)); + + Ok(ForwardResult { + output, + intermediate_lookups: vec![], + }) + } + + /// Returns the name 
of the operation + fn as_string(&self) -> String { + match self { + LookupOp::Abs => "ABS".into(), + LookupOp::Ceil { scale } => format!("CEIL(scale={})", scale), + LookupOp::Floor { scale } => format!("FLOOR(scale={})", scale), + LookupOp::Round { scale } => format!("ROUND(scale={})", scale), + LookupOp::RoundHalfToEven { scale } => format!("ROUND_HALF_TO_EVEN(scale={})", scale), + LookupOp::Pow { a, scale } => format!("POW(scale={}, exponent={})", scale, a), + LookupOp::KroneckerDelta => "K_DELTA".into(), + LookupOp::Max { scales, a } => format!("MAX(scales={:?}, a={})", scales, a), + LookupOp::Min { scales, a } => format!("MIN(scales={:?}, a={})", scales, a), + LookupOp::Sign => "SIGN".into(), + LookupOp::GreaterThan { .. } => "GREATER_THAN".into(), + LookupOp::GreaterThanEqual { .. } => "GREATER_THAN_EQUAL".into(), + LookupOp::LessThan { .. } => "LESS_THAN".into(), + LookupOp::LessThanEqual { .. } => "LESS_THAN_EQUAL".into(), + LookupOp::Recip { scale, .. } => format!("RECIP(scale={})", scale), + LookupOp::Div { denom, .. 
} => format!("DIV(denom={})", denom), + LookupOp::Ln { scale } => format!("LN(scale={})", scale), + LookupOp::ReLU => "RELU".to_string(), + LookupOp::LeakyReLU { slope: a } => format!("L_RELU(slope={})", a), + LookupOp::Sigmoid { scale } => format!("SIGMOID(scale={})", scale), + LookupOp::Sqrt { scale } => format!("SQRT(scale={})", scale), + LookupOp::Erf { scale } => format!("ERF(scale={})", scale), + LookupOp::Rsqrt { scale } => format!("RSQRT(scale={})", scale), + LookupOp::Exp { scale } => format!("EXP(scale={})", scale), + LookupOp::Tan { scale } => format!("TAN(scale={})", scale), + LookupOp::ATan { scale } => format!("ATAN(scale={})", scale), + LookupOp::Tanh { scale } => format!("TANH(scale={})", scale), + LookupOp::ATanh { scale } => format!("ATANH(scale={})", scale), + LookupOp::Cos { scale } => format!("COS(scale={})", scale), + LookupOp::ACos { scale } => format!("ACOS(scale={})", scale), + LookupOp::Cosh { scale } => format!("COSH(scale={})", scale), + LookupOp::ACosh { scale } => format!("ACOSH(scale={})", scale), + LookupOp::Sin { scale } => format!("SIN(scale={})", scale), + LookupOp::ASin { scale } => format!("ASIN(scale={})", scale), + LookupOp::Sinh { scale } => format!("SINH(scale={})", scale), + LookupOp::ASinh { scale } => format!("ASINH(scale={})", scale), + } + } + + fn layout( + &self, + config: &mut crate::circuit::BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor], + ) -> Result>, Box> { + Ok(Some(layouts::nonlinearity( + config, + region, + values[..].try_into()?, + self, + )?)) + } + + /// Returns the scale of the output of the operation. + fn out_scale(&self, inputs_scale: Vec) -> Result> { + let scale = match self { + LookupOp::Div { denom } => { + let mut scale = inputs_scale[0]; + if scale == 0 { + scale += multiplier_to_scale(1. 
/ denom.0 as f64); + } + scale + } + LookupOp::Recip { scale } => { + let mut out_scale = inputs_scale[0]; + out_scale += + multiplier_to_scale(scale.0 as f64 / scale_to_multiplier(out_scale).powf(2.0)); + out_scale + } + LookupOp::Sign + | LookupOp::GreaterThan { .. } + | LookupOp::LessThan { .. } + | LookupOp::GreaterThanEqual { .. } + | LookupOp::LessThanEqual { .. } + | LookupOp::KroneckerDelta + | LookupOp::Round { .. } + | LookupOp::RoundHalfToEven { .. } + | LookupOp::Ceil { .. } + | LookupOp::Floor { .. } => 0, + _ => inputs_scale[0], + }; + Ok(scale) + } + + fn required_lookups(&self) -> Vec { + vec![self.clone()] + } + + fn clone_dyn(&self) -> Box> { + Box::new(self.clone()) // Forward to the derive(Clone) impl + } +} diff --git a/mnist_ezkl/src/circuit/ops/mod.rs b/mnist_ezkl/src/circuit/ops/mod.rs new file mode 100644 index 0000000..5528677 --- /dev/null +++ b/mnist_ezkl/src/circuit/ops/mod.rs @@ -0,0 +1,346 @@ +use std::{any::Any, error::Error}; + +use serde::{Deserialize, Serialize}; + +use crate::{ + graph::quantize_tensor, + tensor::{self, Tensor, TensorError, TensorType, ValTensor}, +}; +use halo2curves::ff::PrimeField; + +use self::{lookup::LookupOp, region::RegionCtx}; + +/// +pub mod base; +/// +pub mod chip; +/// +pub mod hybrid; +/// Layouts for specific functions (composed of base ops) +pub mod layouts; +/// +pub mod lookup; +/// +pub mod poly; +/// +pub mod region; + +/// A struct representing the result of a forward pass. +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub struct ForwardResult { + pub(crate) output: Tensor, + pub(crate) intermediate_lookups: Vec>, +} + +/// A trait representing operations that can be represented as constraints in a circuit. +pub trait Op: std::fmt::Debug + Send + Sync + Any { + /// Matches a [Op] to an operation in the `tensor::ops` module. + fn f(&self, x: &[Tensor]) -> Result, TensorError>; + /// Returns a string representation of the operation. 
+ fn as_string(&self) -> String; + + /// Layouts the operation in a circuit. + fn layout( + &self, + config: &mut crate::circuit::BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor], + ) -> Result>, Box>; + + /// Returns the scale of the output of the operation. + fn out_scale(&self, _: Vec) -> Result>; + + /// Do any of the inputs to this op require homogenous input scales? + fn requires_homogenous_input_scales(&self) -> Vec { + vec![] + } + + /// Returns the lookups required by the operation. + fn required_lookups(&self) -> Vec { + vec![] + } + + /// Returns true if the operation is an input. + fn is_input(&self) -> bool { + false + } + + /// Returns true if the operation is a constant. + fn is_constant(&self) -> bool { + false + } + + /// Boxes and clones + fn clone_dyn(&self) -> Box>; + + /// Returns a reference to the Any trait. + fn as_any(&self) -> &dyn Any; + + /// Safe mode output checl + fn safe_mode_check( + &self, + claimed_output: &ValTensor, + original_values: &[ValTensor], + ) -> Result<(), TensorError> { + let felt_evals = original_values + .iter() + .map(|v| { + let mut evals = v.get_felt_evals().map_err(|_| TensorError::FeltError)?; + evals.reshape(v.dims())?; + Ok(evals) + }) + .collect::, _>>()?; + + let ref_op: Tensor = self.f(&felt_evals)?.output; + + let mut output = claimed_output + .get_felt_evals() + .map_err(|_| TensorError::FeltError)?; + output.reshape(claimed_output.dims())?; + + assert_eq!(output, ref_op); + + Ok(()) + } +} + +impl Clone for Box> { + fn clone(&self) -> Self { + self.clone_dyn() + } +} + +/// +#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)] +pub enum InputType { + /// + Bool, + /// + F16, + /// + F32, + /// + F64, + /// + Int, + /// + TDim, +} + +impl InputType { + /// + pub fn is_integer(&self) -> bool { + matches!(self, InputType::Int | InputType::TDim | InputType::Bool) + } + + /// + pub fn roundtrip(&self, input: &mut T) { + match self { + InputType::Bool => { + let 
boolean_input = input.clone().to_i64().unwrap(); + assert!(boolean_input == 0 || boolean_input == 1); + *input = T::from_i64(boolean_input).unwrap(); + } + InputType::F16 => { + // TODO: implement f16 + let f32_input = input.clone().to_f32().unwrap(); + *input = T::from_f32(f32_input).unwrap(); + } + InputType::F32 => { + let f32_input = input.clone().to_f32().unwrap(); + *input = T::from_f32(f32_input).unwrap(); + } + InputType::F64 => { + let f64_input = input.clone().to_f64().unwrap(); + *input = T::from_f64(f64_input).unwrap(); + } + InputType::Int | InputType::TDim => { + let int_input = input.clone().to_i128().unwrap(); + *input = T::from_i128(int_input).unwrap(); + } + } + } +} + +/// +#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)] +pub struct Input { + /// + pub scale: crate::Scale, + /// + pub datum_type: InputType, +} + +impl Op for Input { + fn out_scale(&self, _: Vec) -> Result> { + Ok(self.scale) + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn f(&self, x: &[Tensor]) -> Result, TensorError> { + Ok(ForwardResult { + output: x[0].clone(), + intermediate_lookups: vec![], + }) + } + + fn as_string(&self) -> String { + "Input".into() + } + + fn layout( + &self, + config: &mut crate::circuit::BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor], + ) -> Result>, Box> { + let value = values[0].clone(); + if !value.all_prev_assigned() { + match self.datum_type { + InputType::Bool => { + log::debug!("constraining input to be boolean"); + Ok(Some(super::layouts::boolean_identity( + config, + region, + values[..].try_into()?, + )?)) + } + _ => Ok(Some(super::layouts::identity( + config, + region, + values[..].try_into()?, + )?)), + } + } else { + Ok(Some(value)) + } + } + + fn is_input(&self) -> bool { + true + } + + fn clone_dyn(&self) -> Box> { + Box::new(self.clone()) // Forward to the derive(Clone) impl + } +} + +/// An unknown operation. 
+#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)] +pub struct Unknown; + +impl Op for Unknown { + fn out_scale(&self, _: Vec) -> Result> { + Ok(0) + } + fn as_any(&self) -> &dyn Any { + self + } + fn f(&self, _: &[Tensor]) -> Result, TensorError> { + Err(TensorError::WrongMethod) + } + + fn as_string(&self) -> String { + "Unknown".into() + } + fn layout( + &self, + _: &mut crate::circuit::BaseConfig, + _: &mut RegionCtx, + _: &[ValTensor], + ) -> Result>, Box> { + Err(Box::new(super::CircuitError::UnsupportedOp)) + } + + fn clone_dyn(&self) -> Box> { + Box::new(self.clone()) // Forward to the derive(Clone) impl + } +} + +/// +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Constant { + /// + pub quantized_values: Tensor, + /// + pub raw_values: Tensor, + /// + #[serde(skip)] + pub pre_assigned_val: Option>, +} + +impl Constant { + /// + pub fn new(quantized_values: Tensor, raw_values: Tensor) -> Self { + Self { + quantized_values, + raw_values, + pre_assigned_val: None, + } + } + /// Rebase the scale of the constant + pub fn rebase_scale(&mut self, new_scale: crate::Scale) -> Result<(), Box> { + let visibility = self.quantized_values.visibility().unwrap(); + self.quantized_values = quantize_tensor(self.raw_values.clone(), new_scale, &visibility)?; + Ok(()) + } + + /// Empty raw value + pub fn empty_raw_value(&mut self) { + self.raw_values = Tensor::new(None, &[0]).unwrap(); + } + + /// + pub fn pre_assign(&mut self, val: ValTensor) { + self.pre_assigned_val = Some(val) + } +} + +impl Deserialize<'de>> Op + for Constant +{ + fn as_any(&self) -> &dyn Any { + self + } + fn f(&self, _: &[Tensor]) -> Result, TensorError> { + let output = self.quantized_values.clone(); + + Ok(ForwardResult { + output, + intermediate_lookups: vec![], + }) + } + + fn as_string(&self) -> String { + format!("CONST (scale={})", self.quantized_values.scale().unwrap()) + } + fn layout( + &self, + config: &mut crate::circuit::BaseConfig, + region: 
&mut RegionCtx, + _: &[ValTensor], + ) -> Result>, Box> { + let value = if let Some(value) = &self.pre_assigned_val { + value.clone() + } else { + self.quantized_values.clone().try_into()? + }; + // we gotta constrain it once if its used multiple times + Ok(Some(layouts::identity(config, region, &[value])?)) + } + + fn clone_dyn(&self) -> Box> { + Box::new(self.clone()) // Forward to the derive(Clone) impl + } + + fn out_scale(&self, _: Vec) -> Result> { + Ok(self.quantized_values.scale().unwrap()) + } + + fn is_constant(&self) -> bool { + true + } +} diff --git a/mnist_ezkl/src/circuit/ops/poly.rs b/mnist_ezkl/src/circuit/ops/poly.rs new file mode 100644 index 0000000..31bcb7c --- /dev/null +++ b/mnist_ezkl/src/circuit/ops/poly.rs @@ -0,0 +1,427 @@ +use crate::{ + circuit::layouts, + tensor::{self, Tensor, TensorError}, +}; + +use super::{base::BaseOp, *}; + +#[allow(missing_docs)] +/// An enum representing the operations that can be expressed as arithmetic (non lookup) operations. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub enum PolyOp { + Einsum { + equation: String, + }, + Conv { + kernel: Tensor, + bias: Option>, + padding: [(usize, usize); 2], + stride: (usize, usize), + }, + Downsample { + axis: usize, + stride: usize, + modulo: usize, + }, + DeConv { + kernel: Tensor, + bias: Option>, + padding: [(usize, usize); 2], + output_padding: (usize, usize), + stride: (usize, usize), + }, + SumPool { + padding: [(usize, usize); 2], + stride: (usize, usize), + kernel_shape: (usize, usize), + }, + Add, + Sub, + Neg, + Mult, + Identity, + Reshape(Vec), + MoveAxis { + source: usize, + destination: usize, + }, + Flatten(Vec), + Pad([(usize, usize); 2]), + Sum { + axes: Vec, + }, + Prod { + axes: Vec, + len_prod: usize, + }, + Pow(u32), + Pack(u32, u32), + GlobalSumPool, + Concat { + axis: usize, + }, + Slice { + axis: usize, + start: usize, + end: usize, + }, + Iff, + Resize { + scale_factor: Vec, + }, + Not, + And, + Or, + Xor, +} + +impl PolyOp {} + +impl 
Deserialize<'de>> Op + for PolyOp +{ + /// Returns a reference to the Any trait. + fn as_any(&self) -> &dyn Any { + self + } + + fn as_string(&self) -> String { + match &self { + PolyOp::MoveAxis { .. } => "MOVEAXIS".into(), + PolyOp::Downsample { .. } => "DOWNSAMPLE".into(), + PolyOp::Resize { .. } => "RESIZE".into(), + PolyOp::Iff => "IFF".into(), + PolyOp::Einsum { equation, .. } => format!("EINSUM {}", equation), + PolyOp::Identity => "IDENTITY".into(), + PolyOp::Reshape(shape) => format!("RESHAPE (shape={:?})", shape), + PolyOp::Flatten(_) => "FLATTEN".into(), + PolyOp::Pad(_) => "PAD".into(), + PolyOp::Add => "ADD".into(), + PolyOp::Mult => "MULT".into(), + PolyOp::Sub => "SUB".into(), + PolyOp::Sum { .. } => "SUM".into(), + PolyOp::Prod { .. } => "PROD".into(), + PolyOp::Pow(_) => "POW".into(), + PolyOp::Pack(_, _) => "PACK".into(), + PolyOp::GlobalSumPool => "GLOBALSUMPOOL".into(), + PolyOp::Conv { .. } => "CONV".into(), + PolyOp::DeConv { .. } => "DECONV".into(), + PolyOp::SumPool { .. } => "SUMPOOL".into(), + PolyOp::Concat { axis } => format!("CONCAT (axis={})", axis), + PolyOp::Slice { axis, start, end } => { + format!("SLICE (axis={}, start={}, end={})", axis, start, end) + } + PolyOp::Neg => "NEG".into(), + PolyOp::Not => "NOT".into(), + PolyOp::And => "AND".into(), + PolyOp::Or => "OR".into(), + PolyOp::Xor => "XOR".into(), + } + } + + /// Matches a [Op] to an operation in the `tensor::ops` module. 
+ fn f(&self, inputs: &[Tensor]) -> Result, TensorError> { + let mut inputs = inputs.to_vec(); + let res = match &self { + PolyOp::And => tensor::ops::and(&inputs[0], &inputs[1]), + PolyOp::Or => tensor::ops::or(&inputs[0], &inputs[1]), + PolyOp::Xor => tensor::ops::xor(&inputs[0], &inputs[1]), + PolyOp::Not => tensor::ops::not(&inputs[0]), + PolyOp::Downsample { + axis, + stride, + modulo, + } => tensor::ops::downsample(&inputs[0], *axis, *stride, *modulo), + PolyOp::Resize { scale_factor } => tensor::ops::resize(&inputs[0], scale_factor), + PolyOp::Iff => tensor::ops::iff(&inputs[0], &inputs[1], &inputs[2]), + PolyOp::Einsum { equation } => tensor::ops::einsum(equation, &inputs), + PolyOp::Identity => Ok(inputs[0].clone()), + PolyOp::Reshape(new_dims) => { + let mut t = inputs[0].clone(); + t.reshape(new_dims)?; + Ok(t) + } + PolyOp::MoveAxis { + source, + destination, + } => inputs[0].move_axis(*source, *destination), + PolyOp::Flatten(new_dims) => { + let mut t = inputs[0].clone(); + t.reshape(new_dims)?; + Ok(t) + } + PolyOp::Pad(p) => { + if 1 != inputs.len() { + return Err(TensorError::DimMismatch("pad inputs".to_string())); + } + tensor::ops::pad(&inputs[0], *p) + } + PolyOp::Add => tensor::ops::add(&inputs), + PolyOp::Neg => tensor::ops::neg(&inputs[0]), + PolyOp::Sub => tensor::ops::sub(&inputs), + PolyOp::Mult => tensor::ops::mult(&inputs), + PolyOp::Conv { + kernel: a, + bias, + padding, + stride, + } => { + inputs.push(a.clone()); + if let Some(b) = bias { + inputs.push(b.clone()); + } + tensor::ops::conv(&inputs, *padding, *stride) + } + PolyOp::DeConv { + kernel: a, + bias, + padding, + output_padding, + stride, + } => { + inputs.push(a.clone()); + if let Some(b) = bias { + inputs.push(b.clone()); + } + tensor::ops::deconv(&inputs, *padding, *output_padding, *stride) + } + PolyOp::SumPool { + padding, + stride, + kernel_shape, + } => tensor::ops::sumpool(&inputs[0], *padding, *stride, *kernel_shape), + PolyOp::Pack(base, scale) => { + if 1 != 
inputs.len() { + return Err(TensorError::DimMismatch("pack inputs".to_string())); + } + + tensor::ops::pack(&inputs[0], F::from(*base as u64), *scale) + } + PolyOp::Pow(u) => { + if 1 != inputs.len() { + return Err(TensorError::DimMismatch("pow inputs".to_string())); + } + inputs[0].pow(*u) + } + PolyOp::Sum { axes } => { + if 1 != inputs.len() { + return Err(TensorError::DimMismatch("sum inputs".to_string())); + } + tensor::ops::sum_axes(&inputs[0], axes) + } + PolyOp::Prod { axes, .. } => { + if 1 != inputs.len() { + return Err(TensorError::DimMismatch("prod inputs".to_string())); + } + tensor::ops::prod_axes(&inputs[0], axes) + } + PolyOp::GlobalSumPool => unreachable!(), + PolyOp::Concat { axis } => { + tensor::ops::concat(&inputs.iter().collect::>(), *axis) + } + PolyOp::Slice { axis, start, end } => { + if 1 != inputs.len() { + return Err(TensorError::DimMismatch("slice inputs".to_string())); + } + Ok(tensor::ops::slice(&inputs[0], axis, start, end)?) + } + }?; + + Ok(ForwardResult { + output: res, + intermediate_lookups: vec![], + }) + } + + fn layout( + &self, + config: &mut crate::circuit::BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor], + ) -> Result>, Box> { + let mut values = values.to_vec(); + + Ok(Some(match self { + PolyOp::Xor => layouts::xor(config, region, values[..].try_into()?)?, + PolyOp::Or => layouts::or(config, region, values[..].try_into()?)?, + PolyOp::And => layouts::and(config, region, values[..].try_into()?)?, + PolyOp::Not => layouts::not(config, region, values[..].try_into()?)?, + PolyOp::MoveAxis { + source, + destination, + } => layouts::move_axis(values[..].try_into()?, *source, *destination)?, + PolyOp::Downsample { + axis, + stride, + modulo, + } => layouts::downsample(config, region, values[..].try_into()?, axis, stride, modulo)?, + PolyOp::Resize { scale_factor } => { + layouts::resize(config, region, values[..].try_into()?, scale_factor)? 
+ } + PolyOp::Neg => layouts::neg(config, region, values[..].try_into()?)?, + PolyOp::Iff => layouts::iff(config, region, values[..].try_into()?)?, + PolyOp::Einsum { equation } => layouts::einsum(config, region, &values, equation)?, + PolyOp::Sum { axes } => { + layouts::sum_axes(config, region, values[..].try_into()?, axes)? + } + PolyOp::Prod { axes, .. } => { + layouts::prod_axes(config, region, values[..].try_into()?, axes)? + } + PolyOp::Conv { + kernel, + bias, + padding, + stride, + } => { + values.push(kernel.clone().try_into()?); + if let Some(bias) = bias { + values.push(bias.clone().try_into()?); + } + layouts::conv(config, region, values[..].try_into()?, *padding, *stride)? + } + PolyOp::DeConv { + kernel, + bias, + padding, + output_padding, + stride, + } => { + values.push(kernel.clone().try_into()?); + if let Some(bias) = bias { + values.push(bias.clone().try_into()?); + } + layouts::deconv( + config, + region, + values[..].try_into()?, + *padding, + *output_padding, + *stride, + )? + } + PolyOp::SumPool { + padding, + stride, + kernel_shape, + } => layouts::sumpool( + config, + region, + values[..].try_into()?, + *padding, + *stride, + *kernel_shape, + )?, + PolyOp::Add => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Add)?, + PolyOp::Sub => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Sub)?, + PolyOp::Mult => { + layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Mult)? + } + PolyOp::Identity => layouts::identity(config, region, values[..].try_into()?)?, + PolyOp::Reshape(d) | PolyOp::Flatten(d) => layouts::reshape(values[..].try_into()?, d)?, + PolyOp::Pad(p) => { + if values.len() != 1 { + return Err(Box::new(TensorError::DimError)); + } + let mut input = values[0].clone(); + input.pad(*p)?; + input + } + PolyOp::Pow(exp) => layouts::pow(config, region, values[..].try_into()?, *exp)?, + PolyOp::Pack(base, scale) => { + layouts::pack(config, region, values[..].try_into()?, *base, *scale)? 
+ } + PolyOp::GlobalSumPool => unreachable!(), + PolyOp::Concat { axis } => layouts::concat(values[..].try_into()?, axis)?, + PolyOp::Slice { axis, start, end } => { + layouts::slice(config, region, values[..].try_into()?, axis, start, end)? + } + })) + } + + fn out_scale(&self, in_scales: Vec) -> Result> { + let scale = match self { + PolyOp::Xor | PolyOp::Or | PolyOp::And | PolyOp::Not => 0, + PolyOp::Neg => in_scales[0], + PolyOp::MoveAxis { .. } => in_scales[0], + PolyOp::Downsample { .. } => in_scales[0], + PolyOp::Resize { .. } => in_scales[0], + PolyOp::Iff => in_scales[1], + PolyOp::Einsum { .. } => { + let mut scale = in_scales[0]; + for s in in_scales.iter().skip(1) { + scale += *s; + } + scale + } + PolyOp::Prod { len_prod, .. } => in_scales[0] * (*len_prod as crate::Scale), + PolyOp::Sum { .. } => in_scales[0], + PolyOp::Conv { kernel, bias, .. } => { + let kernel_scale = match kernel.scale() { + Some(s) => s, + None => return Err("scale must be set for conv kernel".into()), + }; + let output_scale = in_scales[0] + kernel_scale; + if let Some(b) = bias { + let bias_scale = match b.scale() { + Some(s) => s, + None => return Err("scale must be set for conv bias".into()), + }; + assert_eq!(output_scale, bias_scale); + } + output_scale + } + PolyOp::DeConv { kernel, bias, .. } => { + let kernel_scale = match kernel.scale() { + Some(s) => s, + None => return Err("scale must be set for deconv kernel".into()), + }; + let output_scale = in_scales[0] + kernel_scale; + if let Some(b) = bias { + let bias_scale = match b.scale() { + Some(s) => s, + None => return Err("scale must be set for deconv bias".into()), + }; + assert_eq!(output_scale, bias_scale); + } + output_scale + } + PolyOp::SumPool { .. 
} => in_scales[0], + PolyOp::Add => { + let mut scale_a = 0; + let scale_b = in_scales[0]; + scale_a += in_scales[1]; + assert_eq!(scale_a, scale_b); + scale_a + } + PolyOp::Sub => in_scales[0], + PolyOp::Mult => { + let mut scale = in_scales[0]; + scale += in_scales[1]; + scale + } + PolyOp::Identity => in_scales[0], + PolyOp::Reshape(_) | PolyOp::Flatten(_) => in_scales[0], + PolyOp::Pad(_) => in_scales[0], + PolyOp::Pow(pow) => in_scales[0] * (*pow as crate::Scale), + PolyOp::Pack(_, _) => in_scales[0], + PolyOp::GlobalSumPool => in_scales[0], + PolyOp::Concat { axis: _ } => in_scales[0], + PolyOp::Slice { .. } => in_scales[0], + }; + Ok(scale) + } + + fn requires_homogenous_input_scales(&self) -> Vec { + if matches!(self, PolyOp::Add { .. } | PolyOp::Sub) { + vec![0, 1] + } else if matches!(self, PolyOp::Iff) { + vec![1, 2] + } else { + vec![] + } + } + + fn clone_dyn(&self) -> Box> { + Box::new(self.clone()) // Forward to the derive(Clone) impl + } +} diff --git a/mnist_ezkl/src/circuit/ops/region.rs b/mnist_ezkl/src/circuit/ops/region.rs new file mode 100644 index 0000000..4a1aa8d --- /dev/null +++ b/mnist_ezkl/src/circuit/ops/region.rs @@ -0,0 +1,388 @@ +use crate::tensor::{Tensor, TensorError, TensorType, ValTensor, ValType, VarTensor}; +use halo2_proofs::{ + circuit::Region, + plonk::{Error, Selector}, +}; +use halo2curves::ff::PrimeField; +use std::{ + cell::RefCell, + collections::HashSet, + sync::atomic::{AtomicUsize, Ordering}, +}; + +/// Region error +#[derive(Debug, thiserror::Error)] +pub enum RegionError { + /// wrap other regions + #[error("Wrapped region: {0}")] + Wrapped(String), +} + +impl From for RegionError { + fn from(e: String) -> Self { + Self::Wrapped(e) + } +} + +impl From<&str> for RegionError { + fn from(e: &str) -> Self { + Self::Wrapped(e.to_string()) + } +} + +impl From for RegionError { + fn from(e: TensorError) -> Self { + Self::Wrapped(format!("{:?}", e)) + } +} + +impl From for RegionError { + fn from(e: Error) -> Self { + 
Self::Wrapped(format!("{:?}", e)) + } +} + +impl From> for RegionError { + fn from(e: Box) -> Self { + Self::Wrapped(format!("{:?}", e)) + } +} + +#[derive(Debug)] +/// A context for a region +pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd> { + region: Option>>, + row: usize, + linear_coord: usize, + num_inner_cols: usize, + total_constants: usize, +} + +impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> { + /// + pub fn increment_total_constants(&mut self, n: usize) { + self.total_constants += n; + } + + /// Create a new region context + pub fn new(region: Region<'a, F>, row: usize, num_inner_cols: usize) -> RegionCtx<'a, F> { + let region = Some(RefCell::new(region)); + let linear_coord = row * num_inner_cols; + + RegionCtx { + region, + num_inner_cols, + row, + linear_coord, + total_constants: 0, + } + } + /// Create a new region context from a wrapped region + pub fn from_wrapped_region( + region: Option>>, + row: usize, + num_inner_cols: usize, + ) -> RegionCtx<'a, F> { + let linear_coord = row * num_inner_cols; + RegionCtx { + region, + num_inner_cols, + linear_coord, + row, + total_constants: 0, + } + } + + /// Create a new region context + pub fn new_dummy(row: usize, num_inner_cols: usize) -> RegionCtx<'a, F> { + let region = None; + let linear_coord = row * num_inner_cols; + + RegionCtx { + region, + num_inner_cols, + linear_coord, + row, + total_constants: 0, + } + } + + /// Create a new region context + pub fn new_dummy_with_constants( + row: usize, + linear_coord: usize, + constants: usize, + num_inner_cols: usize, + ) -> RegionCtx<'a, F> { + let region = None; + RegionCtx { + region, + num_inner_cols, + linear_coord, + row, + total_constants: constants, + } + } + + /// Apply a function in a loop to the region + pub fn apply_in_loop( + &mut self, + output: &mut Tensor, + inner_loop_function: impl Fn(usize, &mut RegionCtx<'a, F>) -> Result + + Send + + Sync, + ) -> Result<(), RegionError> { + if self.is_dummy() { + 
self.dummy_loop(output, inner_loop_function)?; + } else { + self.real_loop(output, inner_loop_function)?; + } + Ok(()) + } + + /// Run a loop + pub fn real_loop( + &mut self, + output: &mut Tensor, + inner_loop_function: impl Fn(usize, &mut RegionCtx<'a, F>) -> Result, + ) -> Result<(), RegionError> { + output + .iter_mut() + .enumerate() + .map(|(i, o)| { + *o = inner_loop_function(i, self)?; + Ok(()) + }) + .collect::, RegionError>>()?; + + Ok(()) + } + + /// Create a new region context per loop iteration + /// hacky but it works + pub fn dummy_loop( + &mut self, + output: &mut Tensor, + inner_loop_function: impl Fn(usize, &mut RegionCtx<'a, F>) -> Result + + Send + + Sync, + ) -> Result<(), RegionError> { + let row = AtomicUsize::new(self.row()); + let linear_coord = AtomicUsize::new(self.linear_coord()); + let constants = AtomicUsize::new(self.total_constants()); + + *output = output + .par_enum_map(|idx, _| { + // we kick off the loop with the current offset + let starting_offset = row.load(Ordering::SeqCst); + let starting_linear_coord = linear_coord.load(Ordering::SeqCst); + let starting_constants = constants.load(Ordering::SeqCst); + // we need to make sure that the region is not shared between threads + let mut local_reg = Self::new_dummy_with_constants( + starting_offset, + starting_linear_coord, + starting_constants, + self.num_inner_cols, + ); + let res = inner_loop_function(idx, &mut local_reg); + // we update the offset and constants + row.fetch_add(local_reg.row() - starting_offset, Ordering::SeqCst); + linear_coord.fetch_add( + local_reg.linear_coord() - starting_linear_coord, + Ordering::SeqCst, + ); + constants.fetch_add( + local_reg.total_constants() - starting_constants, + Ordering::SeqCst, + ); + res + }) + .map_err(|e| { + log::error!("dummy_loop: {:?}", e); + Error::Synthesis + })?; + self.total_constants = constants.into_inner(); + self.linear_coord = linear_coord.into_inner(); + self.row = row.into_inner(); + Ok(()) + } + + /// Check if the 
region is dummy + pub fn is_dummy(&self) -> bool { + self.region.is_none() + } + + /// duplicate_dummy + pub fn duplicate_dummy(&self) -> Self { + Self { + region: None, + linear_coord: self.linear_coord, + num_inner_cols: self.num_inner_cols, + row: self.row, + total_constants: self.total_constants, + } + } + + /// Get the offset + pub fn row(&self) -> usize { + self.row + } + + /// Linear coordinate + pub fn linear_coord(&self) -> usize { + self.linear_coord + } + + /// Get the total number of constants + pub fn total_constants(&self) -> usize { + self.total_constants + } + + /// Assign a constant value + pub fn assign_constant(&mut self, var: &VarTensor, value: F) -> Result, Error> { + self.total_constants += 1; + if let Some(region) = &self.region { + let cell = var.assign_constant(&mut region.borrow_mut(), self.linear_coord, value)?; + Ok(cell.into()) + } else { + Ok(value.into()) + } + } + /// Assign a valtensor to a vartensor + pub fn assign( + &mut self, + var: &VarTensor, + values: &ValTensor, + ) -> Result, Error> { + self.total_constants += values.num_constants(); + if let Some(region) = &self.region { + var.assign(&mut region.borrow_mut(), self.linear_coord, values) + } else { + Ok(values.clone()) + } + } + + /// Assign a valtensor to a vartensor + pub fn assign_with_omissions( + &mut self, + var: &VarTensor, + values: &ValTensor, + ommissions: &HashSet<&usize>, + ) -> Result, Error> { + if let Some(region) = &self.region { + var.assign_with_omissions( + &mut region.borrow_mut(), + self.linear_coord, + values, + ommissions, + ) + } else { + self.total_constants += values.num_constants(); + let inner_tensor = values.get_inner_tensor().unwrap(); + for o in ommissions { + self.total_constants -= inner_tensor.get_flat_index(**o).is_constant() as usize; + } + Ok(values.clone()) + } + } + + /// Assign a valtensor to a vartensor with duplication + pub fn assign_with_duplication( + &mut self, + var: &VarTensor, + values: &ValTensor, + check_mode: 
&crate::circuit::CheckMode, + single_inner_col: bool, + ) -> Result<(ValTensor, usize), Error> { + if let Some(region) = &self.region { + // duplicates every nth element to adjust for column overflow + let (res, len, total_assigned_constants) = var.assign_with_duplication( + &mut region.borrow_mut(), + self.row, + self.linear_coord, + values, + check_mode, + single_inner_col, + )?; + self.total_constants += total_assigned_constants; + Ok((res, len)) + } else { + let (_, len, total_assigned_constants) = var.dummy_assign_with_duplication( + self.row, + self.linear_coord, + values, + single_inner_col, + )?; + self.total_constants += total_assigned_constants; + Ok((values.clone(), len)) + } + } + + /// Enable a selector + pub fn enable(&mut self, selector: Option<&Selector>, offset: usize) -> Result<(), Error> { + match &self.region { + Some(region) => selector.unwrap().enable(&mut region.borrow_mut(), offset), + None => Ok(()), + } + } + + /// constrain equal + pub fn constrain_equal(&mut self, a: &ValTensor, b: &ValTensor) -> Result<(), Error> { + if let Some(region) = &self.region { + let a = a.get_inner_tensor().unwrap(); + let b = b.get_inner_tensor().unwrap(); + assert_eq!(a.len(), b.len()); + a.iter().zip(b.iter()).try_for_each(|(a, b)| { + let a = a.get_prev_assigned(); + let b = b.get_prev_assigned(); + // if they're both assigned, we can constrain them + if let (Some(a), Some(b)) = (&a, &b) { + region.borrow_mut().constrain_equal(a.cell(), b.cell()) + } else if a.is_some() || b.is_some() { + log::error!( + "constrain_equal: one of the tensors is assigned and the other is not" + ); + return Err(Error::Synthesis); + } else { + Ok(()) + } + }) + } else { + Ok(()) + } + } + + /// Increment the offset by 1 + pub fn next(&mut self) { + self.linear_coord += 1; + if self.linear_coord % self.num_inner_cols == 0 { + self.row += 1; + } + } + + /// Increment the offset + pub fn increment(&mut self, n: usize) { + for _ in 0..n { + self.next() + } + } + + /// flush row to 
the next row + pub fn flush(&mut self) -> Result<(), Box> { + // increment by the difference between the current linear coord and the next row + let remainder = self.linear_coord % self.num_inner_cols; + if remainder != 0 { + let diff = self.num_inner_cols - remainder; + self.increment(diff); + } + if !(self.linear_coord % self.num_inner_cols == 0) { + return Err("flush: linear coord is not aligned with the next row".into()); + } + Ok(()) + } + + /// increment constants + pub fn increment_constants(&mut self, n: usize) { + self.total_constants += n + } +} diff --git a/mnist_ezkl/src/circuit/table.rs b/mnist_ezkl/src/circuit/table.rs new file mode 100644 index 0000000..6527fc5 --- /dev/null +++ b/mnist_ezkl/src/circuit/table.rs @@ -0,0 +1,259 @@ +use std::{error::Error, marker::PhantomData}; + +use halo2curves::ff::PrimeField; + +use halo2_proofs::{ + circuit::{Layouter, Value}, + plonk::{ConstraintSystem, Expression, TableColumn}, +}; +use log::warn; +use rayon::prelude::{IntoParallelIterator, ParallelIterator}; + +use crate::{ + circuit::CircuitError, + fieldutils::i128_to_felt, + tensor::{Tensor, TensorType}, +}; + +use crate::circuit::lookup::LookupOp; + +use super::Op; + +/// The safety factor for the range of the lookup table. +pub const RANGE_MULTIPLIER: i128 = 2; +/// The safety factor offset for the number of rows in the lookup table. 
+pub const RESERVED_BLINDING_ROWS_PAD: usize = 3; + +#[derive(Debug, Clone)] +/// +pub struct SelectorConstructor { + /// + pub degree: usize, + /// + _marker: PhantomData, +} + +impl SelectorConstructor { + /// + pub fn new(degree: usize) -> Self { + Self { + degree, + _marker: PhantomData, + } + } + + /// + pub fn get_expr_at_idx(&self, i: usize, expr: Expression) -> Expression { + let indices = 0..self.degree; + indices + .into_par_iter() + .filter(|x| *x != i) + .map(|i| { + if i == 0 { + expr.clone() + } else { + (Expression::Constant(F::from(i as u64))) - expr.clone() + } + }) + .reduce(|| Expression::Constant(F::from(1_u64)), |acc, x| acc * x) + } + + /// + pub fn get_selector_val_at_idx(&self, i: usize) -> F { + let indices = 0..self.degree; + indices + .into_par_iter() + .filter(|x| *x != i) + .map(|x| { + if x == 0 { + F::from(i as u64) + } else { + F::from(x as u64) - F::from(i as u64) + } + }) + .product() + } +} + +/// Halo2 lookup table for element wise non-linearities. +#[derive(Clone, Debug)] +pub struct Table { + /// Non-linearity to be used in table. + pub nonlinearity: LookupOp, + /// Input to table. + pub table_inputs: Vec, + /// col size + pub col_size: usize, + /// Output of table + pub table_outputs: Vec, + /// selector cn + pub selector_constructor: SelectorConstructor, + /// Flags if table has been previously assigned to. + pub is_assigned: bool, + /// Number of bits used in lookup table. 
+ pub range: (i128, i128), + _marker: PhantomData, +} + +impl Table { + /// get column index given input + pub fn get_col_index(&self, input: F) -> F { + // range is split up into chunks of size col_size, find the chunk that input is in + let chunk = + (crate::fieldutils::felt_to_i128(input) - self.range.0).abs() / (self.col_size as i128); + + i128_to_felt(chunk) + } + + /// get first_element of column + pub fn get_first_element(&self, chunk: usize) -> (F, F) { + let chunk = chunk as i128; + // we index from 1 to prevent soundness issues + let first_element = i128_to_felt(chunk * (self.col_size as i128) + self.range.0); + let op_f = Op::::f( + &self.nonlinearity, + &[Tensor::from(vec![first_element].into_iter())], + ) + .unwrap(); + (first_element, op_f.output[0]) + } + + /// + pub fn cal_col_size(logrows: usize, reserved_blinding_rows: usize) -> usize { + 2usize.pow(logrows as u32) - reserved_blinding_rows + } + + /// + pub fn cal_bit_range(bits: usize, reserved_blinding_rows: usize) -> usize { + 2usize.pow(bits as u32) - reserved_blinding_rows + } + + /// + pub fn num_cols_required(range: (i128, i128), col_size: usize) -> usize { + // double it to be safe + let range_len = range.1 - range.0; + // number of cols needed to store the range + (range_len / (col_size as i128)) as usize + 1 + } +} + +impl Table { + /// Configures the table. 
+ pub fn configure( + cs: &mut ConstraintSystem, + range: (i128, i128), + logrows: usize, + nonlinearity: &LookupOp, + preexisting_inputs: Option>, + ) -> Table { + let factors = cs.blinding_factors() + RESERVED_BLINDING_ROWS_PAD; + let col_size = Self::cal_col_size(logrows, factors); + // number of cols needed to store the range + let num_cols = Self::num_cols_required(range, col_size); + + log::debug!("table range: {:?}", range); + + let table_inputs = preexisting_inputs.unwrap_or_else(|| { + let mut cols = vec![]; + for _ in 0..num_cols { + cols.push(cs.lookup_table_column()); + } + cols + }); + + let num_cols = table_inputs.len(); + + if num_cols > 1 { + warn!("Using {} columns for non-linearity table.", num_cols); + } + + let table_outputs = table_inputs + .iter() + .map(|_| cs.lookup_table_column()) + .collect::>(); + + Table { + nonlinearity: nonlinearity.clone(), + table_inputs, + table_outputs, + is_assigned: false, + selector_constructor: SelectorConstructor::new(num_cols), + col_size, + range, + _marker: PhantomData, + } + } + + /// Take a linear coordinate and output the (column, row) position in the storage block. + pub fn cartesian_coord(&self, linear_coord: usize) -> (usize, usize) { + let x = linear_coord / self.col_size; + let y = linear_coord % self.col_size; + (x, y) + } + + /// Assigns values to the constraints generated when calling `configure`. 
+ pub fn layout( + &mut self, + layouter: &mut impl Layouter, + preassigned_input: bool, + ) -> Result<(), Box> { + if self.is_assigned { + return Err(Box::new(CircuitError::TableAlreadyAssigned)); + } + + let smallest = self.range.0; + let largest = self.range.1; + + let inputs = Tensor::from(smallest..=largest).map(|x| i128_to_felt(x)); + let evals = Op::::f(&self.nonlinearity, &[inputs.clone()])?; + let chunked_inputs = inputs.chunks(self.col_size); + + self.is_assigned = true; + + let col_multipliers: Vec = (0..chunked_inputs.len()) + .map(|x| self.selector_constructor.get_selector_val_at_idx(x)) + .collect(); + + let _ = chunked_inputs + .enumerate() + .map(|(chunk_idx, inputs)| { + layouter.assign_table( + || "nl table", + |mut table| { + let _ = inputs + .iter() + .enumerate() + .map(|(mut row_offset, input)| { + let col_multiplier = col_multipliers[chunk_idx]; + + row_offset += chunk_idx * self.col_size; + let (x, y) = self.cartesian_coord(row_offset); + if !preassigned_input { + table.assign_cell( + || format!("nl_i_col row {}", row_offset), + self.table_inputs[x], + y, + || Value::known(*input * col_multiplier), + )?; + } + + let output = evals.output[row_offset]; + + table.assign_cell( + || format!("nl_o_col row {}", row_offset), + self.table_outputs[x], + y, + || Value::known(output * col_multiplier), + )?; + + Ok(()) + }) + .collect::, halo2_proofs::plonk::Error>>()?; + Ok(()) + }, + ) + }) + .collect::, halo2_proofs::plonk::Error>>()?; + Ok(()) + } +} diff --git a/mnist_ezkl/src/circuit/tests.rs b/mnist_ezkl/src/circuit/tests.rs new file mode 100644 index 0000000..1a68505 --- /dev/null +++ b/mnist_ezkl/src/circuit/tests.rs @@ -0,0 +1,2642 @@ +use crate::circuit::ops::hybrid::HybridOp; +use crate::circuit::ops::poly::PolyOp; +use crate::circuit::*; +use crate::tensor::{Tensor, TensorType, ValTensor, VarTensor}; +use halo2_proofs::{ + circuit::{Layouter, SimpleFloorPlanner, Value}, + dev::MockProver, + plonk::{Circuit, ConstraintSystem, Error}, +}; +use 
halo2curves::bn256::Fr as F; +use halo2curves::ff::{Field, PrimeField}; +use ops::lookup::LookupOp; +use ops::region::RegionCtx; +use rand::rngs::OsRng; +use std::marker::PhantomData; + +#[derive(Default)] +struct TestParams; + +#[cfg(test)] +mod matmul { + + use super::*; + + const K: usize = 9; + const LEN: usize = 3; + + #[derive(Clone)] + struct MatmulCircuit { + inputs: [ValTensor; 2], + _marker: PhantomData, + } + + impl Circuit for MatmulCircuit { + type Config = BaseConfig; + type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; + + fn without_witnesses(&self) -> Self { + self.clone() + } + + fn configure(cs: &mut ConstraintSystem) -> Self::Config { + let a = VarTensor::new_advice(cs, K, 1, LEN * LEN); + let b = VarTensor::new_advice(cs, K, 1, LEN * LEN); + let output = VarTensor::new_advice(cs, K, 1, LEN * LEN); + Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE) + } + + fn synthesize( + &self, + mut config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + layouter + .assign_region( + || "", + |region| { + let mut region = RegionCtx::new(region, 0, 1); + config + .layout( + &mut region, + &self.inputs.clone(), + Box::new(PolyOp::Einsum { + equation: "ij,jk->ik".to_string(), + }), + ) + .map_err(|_| Error::Synthesis) + }, + ) + .unwrap(); + + Ok(()) + } + } + + #[test] + fn matmulcircuit() { + // parameters + let mut a = + Tensor::from((0..(LEN + 1) * LEN).map(|i| Value::known(F::from((i + 1) as u64)))); + a.reshape(&[LEN, LEN + 1]).unwrap(); + + let mut w = Tensor::from((0..LEN + 1).map(|i| Value::known(F::from((i + 1) as u64)))); + w.reshape(&[LEN + 1, 1]).unwrap(); + + let circuit = MatmulCircuit:: { + inputs: [ValTensor::from(a), ValTensor::from(w)], + _marker: PhantomData, + }; + + let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap(); + prover.assert_satisfied_par(); + } +} + +#[cfg(test)] +mod matmul_col_overflow_double_col { + use super::*; + + const K: usize = 5; + const LEN: usize = 
6; + const NUM_INNER_COLS: usize = 2; + + #[derive(Clone)] + struct MatmulCircuit { + inputs: [ValTensor; 2], + _marker: PhantomData, + } + + impl Circuit for MatmulCircuit { + type Config = BaseConfig; + type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; + + fn without_witnesses(&self) -> Self { + self.clone() + } + + fn configure(cs: &mut ConstraintSystem) -> Self::Config { + let a = VarTensor::new_advice(cs, K, NUM_INNER_COLS, LEN * LEN * LEN); + let b = VarTensor::new_advice(cs, K, NUM_INNER_COLS, LEN * LEN * LEN); + let output = VarTensor::new_advice(cs, K, NUM_INNER_COLS, LEN * LEN * LEN); + Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE) + } + + fn synthesize( + &self, + mut config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + layouter + .assign_region( + || "", + |region| { + let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS); + config + .layout( + &mut region, + &self.inputs.clone(), + Box::new(PolyOp::Einsum { + equation: "ij,jk->ik".to_string(), + }), + ) + .map_err(|_| Error::Synthesis) + }, + ) + .unwrap(); + Ok(()) + } + } + + #[test] + fn matmulcircuit() { + // parameters + let mut a = Tensor::from((0..LEN * LEN).map(|i| Value::known(F::from((i + 1) as u64)))); + a.reshape(&[LEN, LEN]).unwrap(); + + let mut w = Tensor::from((0..LEN).map(|i| Value::known(F::from((i + 1) as u64)))); + w.reshape(&[LEN, 1]).unwrap(); + + let circuit = MatmulCircuit:: { + inputs: [ValTensor::from(a), ValTensor::from(w)], + _marker: PhantomData, + }; + + let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap(); + prover.assert_satisfied_par(); + } +} + +#[cfg(test)] +mod matmul_col_overflow { + use super::*; + + const K: usize = 5; + const LEN: usize = 6; + + #[derive(Clone)] + struct MatmulCircuit { + inputs: [ValTensor; 2], + _marker: PhantomData, + } + + impl Circuit for MatmulCircuit { + type Config = BaseConfig; + type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; + + fn 
without_witnesses(&self) -> Self { + self.clone() + } + + fn configure(cs: &mut ConstraintSystem) -> Self::Config { + let a = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN); + let b = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN); + let output = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN); + Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE) + } + + fn synthesize( + &self, + mut config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + layouter + .assign_region( + || "", + |region| { + let mut region = RegionCtx::new(region, 0, 1); + config + .layout( + &mut region, + &self.inputs.clone(), + Box::new(PolyOp::Einsum { + equation: "ij,jk->ik".to_string(), + }), + ) + .map_err(|_| Error::Synthesis) + }, + ) + .unwrap(); + Ok(()) + } + } + + #[test] + fn matmulcircuit() { + // parameters + let mut a = Tensor::from((0..LEN * LEN).map(|i| Value::known(F::from((i + 1) as u64)))); + a.reshape(&[LEN, LEN]).unwrap(); + + let mut w = Tensor::from((0..LEN).map(|i| Value::known(F::from((i + 1) as u64)))); + w.reshape(&[LEN, 1]).unwrap(); + + let circuit = MatmulCircuit:: { + inputs: [ValTensor::from(a), ValTensor::from(w)], + _marker: PhantomData, + }; + + let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap(); + prover.assert_satisfied_par(); + } +} + +#[cfg(test)] +#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] +mod matmul_col_ultra_overflow_double_col { + use halo2_proofs::poly::commitment::ParamsProver; + + use super::*; + + const K: usize = 4; + const LEN: usize = 20; + const NUM_INNER_COLS: usize = 2; + + #[derive(Clone)] + struct MatmulCircuit { + inputs: [ValTensor; 2], + _marker: PhantomData, + } + + impl Circuit for MatmulCircuit { + type Config = BaseConfig; + type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; + + fn without_witnesses(&self) -> Self { + self.clone() + } + + fn configure(cs: &mut ConstraintSystem) -> Self::Config { + let a = VarTensor::new_advice(cs, K, 
NUM_INNER_COLS, LEN * LEN * LEN); + let b = VarTensor::new_advice(cs, K, NUM_INNER_COLS, LEN * LEN * LEN); + let output = VarTensor::new_advice(cs, K, NUM_INNER_COLS, LEN * LEN * LEN); + Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE) + } + + fn synthesize( + &self, + mut config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + layouter + .assign_region( + || "", + |region| { + let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS); + config + .layout( + &mut region, + &self.inputs.clone(), + Box::new(PolyOp::Einsum { + equation: "ij,jk->ik".to_string(), + }), + ) + .map_err(|_| Error::Synthesis) + }, + ) + .unwrap(); + Ok(()) + } + } + + #[test] + #[ignore] + fn matmulcircuit() { + // get some logs fam + crate::logger::init_logger(); + // parameters + let mut a = Tensor::from((0..LEN * LEN).map(|i| Value::known(F::from((i + 1) as u64)))); + a.reshape(&[LEN, LEN]).unwrap(); + + let mut w = Tensor::from((0..LEN).map(|i| Value::known(F::from((i + 1) as u64)))); + w.reshape(&[LEN, 1]).unwrap(); + + let circuit = MatmulCircuit:: { + inputs: [ValTensor::from(a), ValTensor::from(w)], + _marker: PhantomData, + }; + + let params = crate::pfsys::srs::gen_srs::< + halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<_>, + >(K as u32); + + let pk = crate::pfsys::create_keys::< + halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme, + F, + MatmulCircuit, + >(&circuit, ¶ms) + .unwrap(); + + let prover = crate::pfsys::create_proof_circuit_kzg( + circuit.clone(), + ¶ms, + None, + &pk, + crate::pfsys::TranscriptType::EVM, + halo2_proofs::poly::kzg::strategy::SingleStrategy::new(¶ms), + // use safe mode to verify that the proof is correct + CheckMode::SAFE, + None, + ); + + assert!(prover.is_ok()); + + let proof = prover.unwrap(); + + let strategy = + halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params()); + let vk = pk.get_vk(); + let result = + 
crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, vk, strategy); + + assert!(result.is_ok()); + + println!("done."); + } +} + +#[cfg(test)] +#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] +mod matmul_col_ultra_overflow { + use halo2_proofs::poly::commitment::ParamsProver; + + use super::*; + + const K: usize = 4; + const LEN: usize = 20; + + #[derive(Clone)] + struct MatmulCircuit { + inputs: [ValTensor; 2], + _marker: PhantomData, + } + + impl Circuit for MatmulCircuit { + type Config = BaseConfig; + type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; + + fn without_witnesses(&self) -> Self { + self.clone() + } + + fn configure(cs: &mut ConstraintSystem) -> Self::Config { + let a = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN); + let b = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN); + let output = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN); + Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE) + } + + fn synthesize( + &self, + mut config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + layouter + .assign_region( + || "", + |region| { + let mut region = RegionCtx::new(region, 0, 1); + config + .layout( + &mut region, + &self.inputs.clone(), + Box::new(PolyOp::Einsum { + equation: "ij,jk->ik".to_string(), + }), + ) + .map_err(|_| Error::Synthesis) + }, + ) + .unwrap(); + Ok(()) + } + } + + #[test] + #[ignore] + fn matmulcircuit() { + // get some logs fam + crate::logger::init_logger(); + // parameters + let mut a = Tensor::from((0..LEN * LEN).map(|i| Value::known(F::from((i + 1) as u64)))); + a.reshape(&[LEN, LEN]).unwrap(); + + let mut w = Tensor::from((0..LEN).map(|i| Value::known(F::from((i + 1) as u64)))); + w.reshape(&[LEN, 1]).unwrap(); + + let circuit = MatmulCircuit:: { + inputs: [ValTensor::from(a), ValTensor::from(w)], + _marker: PhantomData, + }; + + let params = crate::pfsys::srs::gen_srs::< + 
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<_>, + >(K as u32); + + let pk = crate::pfsys::create_keys::< + halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme, + F, + MatmulCircuit, + >(&circuit, ¶ms) + .unwrap(); + + let prover = crate::pfsys::create_proof_circuit_kzg( + circuit.clone(), + ¶ms, + None, + &pk, + crate::pfsys::TranscriptType::EVM, + halo2_proofs::poly::kzg::strategy::SingleStrategy::new(¶ms), + // use safe mode to verify that the proof is correct + CheckMode::SAFE, + None, + ); + + assert!(prover.is_ok()); + + let proof = prover.unwrap(); + + let strategy = + halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params()); + let vk = pk.get_vk(); + let result = + crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, vk, strategy); + + assert!(result.is_ok()); + + println!("done."); + } +} + +#[cfg(test)] +mod dot { + use ops::poly::PolyOp; + + use super::*; + + const K: usize = 4; + const LEN: usize = 4; + + #[derive(Clone)] + struct MyCircuit { + inputs: [ValTensor; 2], + _marker: PhantomData, + } + + impl Circuit for MyCircuit { + type Config = BaseConfig; + type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; + + fn without_witnesses(&self) -> Self { + self.clone() + } + + fn configure(cs: &mut ConstraintSystem) -> Self::Config { + let a = VarTensor::new_advice(cs, K, 1, LEN); + let b = VarTensor::new_advice(cs, K, 1, LEN); + let output = VarTensor::new_advice(cs, K, 1, LEN); + + Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE) + } + + fn synthesize( + &self, + mut config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + layouter + .assign_region( + || "", + |region| { + let mut region = RegionCtx::new(region, 0, 1); + config + .layout( + &mut region, + &self.inputs.clone(), + Box::new(PolyOp::Einsum { + equation: "i,i->".to_string(), + }), + ) + .map_err(|_| Error::Synthesis) + }, + ) + .unwrap(); + Ok(()) + } + } + + #[test] + fn 
dotcircuit() { + // parameters + let a = Tensor::from((0..LEN).map(|i| Value::known(F::from(i as u64 + 1)))); + + let b = Tensor::from((0..LEN).map(|i| Value::known(F::from(i as u64 + 1)))); + + let circuit = MyCircuit:: { + inputs: [ValTensor::from(a), ValTensor::from(b)], + _marker: PhantomData, + }; + + let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap(); + prover.assert_satisfied_par(); + } +} + +#[cfg(test)] +mod dot_col_overflow_triple_col { + use super::*; + + const K: usize = 4; + const LEN: usize = 50; + + #[derive(Clone)] + struct MyCircuit { + inputs: [ValTensor; 2], + _marker: PhantomData, + } + + impl Circuit for MyCircuit { + type Config = BaseConfig; + type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; + + fn without_witnesses(&self) -> Self { + self.clone() + } + + fn configure(cs: &mut ConstraintSystem) -> Self::Config { + // used for constants in the padding + let _fixed = cs.fixed_column(); + cs.enable_constant(_fixed); + + let a = VarTensor::new_advice(cs, K, 3, LEN); + let b = VarTensor::new_advice(cs, K, 3, LEN); + let output = VarTensor::new_advice(cs, K, 3, LEN); + + Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE) + } + + fn synthesize( + &self, + mut config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + layouter + .assign_region( + || "", + |region| { + let mut region = RegionCtx::new(region, 0, 3); + config + .layout( + &mut region, + &self.inputs.clone(), + Box::new(PolyOp::Einsum { + equation: "i,i->".to_string(), + }), + ) + .map_err(|_| Error::Synthesis) + }, + ) + .unwrap(); + Ok(()) + } + } + + #[test] + fn dotcircuit() { + // parameters + let a = Tensor::from((0..LEN).map(|i| Value::known(F::from(i as u64 + 1)))); + + let b = Tensor::from((0..LEN).map(|i| Value::known(F::from(i as u64 + 1)))); + + let circuit = MyCircuit:: { + inputs: [ValTensor::from(a), ValTensor::from(b)], + _marker: PhantomData, + }; + + let prover = MockProver::run(K as u32, &circuit, 
vec![]).unwrap(); + prover.assert_satisfied_par(); + } +} + +#[cfg(test)] +mod dot_col_overflow { + use super::*; + + const K: usize = 4; + const LEN: usize = 50; + + #[derive(Clone)] + struct MyCircuit { + inputs: [ValTensor; 2], + _marker: PhantomData, + } + + impl Circuit for MyCircuit { + type Config = BaseConfig; + type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; + + fn without_witnesses(&self) -> Self { + self.clone() + } + + fn configure(cs: &mut ConstraintSystem) -> Self::Config { + let a = VarTensor::new_advice(cs, K, 1, LEN); + let b = VarTensor::new_advice(cs, K, 1, LEN); + let output = VarTensor::new_advice(cs, K, 1, LEN); + + Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE) + } + + fn synthesize( + &self, + mut config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + layouter + .assign_region( + || "", + |region| { + let mut region = RegionCtx::new(region, 0, 1); + config + .layout( + &mut region, + &self.inputs.clone(), + Box::new(PolyOp::Einsum { + equation: "i,i->".to_string(), + }), + ) + .map_err(|_| Error::Synthesis) + }, + ) + .unwrap(); + Ok(()) + } + } + + #[test] + fn dotcircuit() { + // parameters + let a = Tensor::from((0..LEN).map(|i| Value::known(F::from(i as u64 + 1)))); + + let b = Tensor::from((0..LEN).map(|i| Value::known(F::from(i as u64 + 1)))); + + let circuit = MyCircuit:: { + inputs: [ValTensor::from(a), ValTensor::from(b)], + _marker: PhantomData, + }; + + let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap(); + prover.assert_satisfied_par(); + } +} + +#[cfg(test)] +mod sum { + use super::*; + + const K: usize = 4; + const LEN: usize = 4; + + #[derive(Clone)] + struct MyCircuit { + inputs: [ValTensor; 1], + _marker: PhantomData, + } + + impl Circuit for MyCircuit { + type Config = BaseConfig; + type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; + + fn without_witnesses(&self) -> Self { + self.clone() + } + + fn configure(cs: &mut 
ConstraintSystem) -> Self::Config { + let a = VarTensor::new_advice(cs, K, 1, LEN); + let b = VarTensor::new_advice(cs, K, 1, LEN); + let output = VarTensor::new_advice(cs, K, 1, LEN); + + Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE) + } + + fn synthesize( + &self, + mut config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + layouter + .assign_region( + || "", + |region| { + let mut region = RegionCtx::new(region, 0, 1); + config + .layout( + &mut region, + &self.inputs.clone(), + Box::new(PolyOp::Sum { axes: vec![0] }), + ) + .map_err(|_| Error::Synthesis) + }, + ) + .unwrap(); + Ok(()) + } + } + + #[test] + fn sumcircuit() { + // parameters + let a = Tensor::from((0..LEN).map(|i| Value::known(F::from(i as u64 + 1)))); + + let circuit = MyCircuit:: { + inputs: [ValTensor::from(a)], + _marker: PhantomData, + }; + + let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap(); + prover.assert_satisfied_par(); + } +} + +#[cfg(test)] +mod sum_col_overflow_double_col { + use super::*; + + const K: usize = 4; + const LEN: usize = 20; + const NUM_INNER_COLS: usize = 2; + + #[derive(Clone)] + struct MyCircuit { + inputs: [ValTensor; 1], + _marker: PhantomData, + } + + impl Circuit for MyCircuit { + type Config = BaseConfig; + type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; + + fn without_witnesses(&self) -> Self { + self.clone() + } + + fn configure(cs: &mut ConstraintSystem) -> Self::Config { + let a = VarTensor::new_advice(cs, K, NUM_INNER_COLS, LEN); + let b = VarTensor::new_advice(cs, K, NUM_INNER_COLS, LEN); + let output = VarTensor::new_advice(cs, K, NUM_INNER_COLS, LEN); + + Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE) + } + + fn synthesize( + &self, + mut config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + layouter + .assign_region( + || "", + |region| { + let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS); + config + .layout( + &mut 
region, + &self.inputs.clone(), + Box::new(PolyOp::Sum { axes: vec![0] }), + ) + .map_err(|_| Error::Synthesis) + }, + ) + .unwrap(); + Ok(()) + } + } + + #[test] + fn sumcircuit() { + // parameters + let a = Tensor::from((0..LEN).map(|i| Value::known(F::from(i as u64 + 1)))); + + let circuit = MyCircuit:: { + inputs: [ValTensor::from(a)], + _marker: PhantomData, + }; + + let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap(); + prover.assert_satisfied_par(); + } +} + +#[cfg(test)] +mod sum_col_overflow { + use super::*; + + const K: usize = 4; + const LEN: usize = 20; + + #[derive(Clone)] + struct MyCircuit { + inputs: [ValTensor; 1], + _marker: PhantomData, + } + + impl Circuit for MyCircuit { + type Config = BaseConfig; + type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; + + fn without_witnesses(&self) -> Self { + self.clone() + } + + fn configure(cs: &mut ConstraintSystem) -> Self::Config { + let a = VarTensor::new_advice(cs, K, 1, LEN); + let b = VarTensor::new_advice(cs, K, 1, LEN); + let output = VarTensor::new_advice(cs, K, 1, LEN); + + Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE) + } + + fn synthesize( + &self, + mut config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + layouter + .assign_region( + || "", + |region| { + let mut region = RegionCtx::new(region, 0, 1); + config + .layout( + &mut region, + &self.inputs.clone(), + Box::new(PolyOp::Sum { axes: vec![0] }), + ) + .map_err(|_| Error::Synthesis) + }, + ) + .unwrap(); + Ok(()) + } + } + + #[test] + fn sumcircuit() { + // parameters + let a = Tensor::from((0..LEN).map(|i| Value::known(F::from(i as u64 + 1)))); + + let circuit = MyCircuit:: { + inputs: [ValTensor::from(a)], + _marker: PhantomData, + }; + + let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap(); + prover.assert_satisfied_par(); + } +} + +#[cfg(test)] +mod composition { + + use super::*; + + const K: usize = 9; + const LEN: usize = 4; + + 
#[derive(Clone)] + struct MyCircuit { + inputs: [ValTensor; 2], + _marker: PhantomData, + } + + impl Circuit for MyCircuit { + type Config = BaseConfig; + type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; + + fn without_witnesses(&self) -> Self { + self.clone() + } + + fn configure(cs: &mut ConstraintSystem) -> Self::Config { + let a = VarTensor::new_advice(cs, K, 1, LEN); + let b = VarTensor::new_advice(cs, K, 1, LEN); + let output = VarTensor::new_advice(cs, K, 1, LEN); + + Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE) + } + + fn synthesize( + &self, + mut config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + // lots of stacked dot products + layouter + .assign_region( + || "", + |region| { + let mut region = RegionCtx::new(region, 0, 1); + let _ = config + .layout( + &mut region, + &self.inputs.clone(), + Box::new(PolyOp::Einsum { + equation: "i,i->".to_string(), + }), + ) + .unwrap(); + let _ = config + .layout( + &mut region, + &self.inputs.clone(), + Box::new(PolyOp::Einsum { + equation: "i,i->".to_string(), + }), + ) + .unwrap(); + config + .layout( + &mut region, + &self.inputs.clone(), + Box::new(PolyOp::Einsum { + equation: "i,i->".to_string(), + }), + ) + .map_err(|_| Error::Synthesis) + }, + ) + .unwrap(); + Ok(()) + } + } + + #[test] + fn dotcircuit() { + // parameters + let a = Tensor::from((0..LEN).map(|i| Value::known(F::from(i as u64 + 1)))); + + let b = Tensor::from((0..LEN).map(|i| Value::known(F::from(i as u64 + 1)))); + + let circuit = MyCircuit:: { + inputs: [ValTensor::from(a), ValTensor::from(b)], + _marker: PhantomData, + }; + + let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap(); + prover.assert_satisfied_par(); + } +} + +#[cfg(test)] +mod conv { + + use super::*; + + const K: usize = 22; + const LEN: usize = 100; + + #[derive(Clone)] + struct ConvCircuit { + inputs: Vec>, + _marker: PhantomData, + } + + impl Circuit for ConvCircuit { + type Config = BaseConfig; + 
type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; + + fn without_witnesses(&self) -> Self { + self.clone() + } + + fn configure(cs: &mut ConstraintSystem) -> Self::Config { + let a = VarTensor::new_advice(cs, K, 1, (LEN + 1) * LEN); + let b = VarTensor::new_advice(cs, K, 1, (LEN + 1) * LEN); + let output = VarTensor::new_advice(cs, K, 1, (LEN + 1) * LEN); + Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE) + } + + fn synthesize( + &self, + mut config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + layouter + .assign_region( + || "", + |region| { + let mut region = RegionCtx::new(region, 0, 1); + config + .layout( + &mut region, + &[self.inputs[0].clone().try_into().unwrap()], + Box::new(PolyOp::Conv { + kernel: self.inputs[1].clone(), + bias: None, + padding: [(1, 1); 2], + stride: (2, 2), + }), + ) + .map_err(|_| Error::Synthesis) + }, + ) + .unwrap(); + Ok(()) + } + } + + #[test] + fn convcircuit() { + // parameters + let kernel_height = 2; + let kernel_width = 3; + let image_height = 5; + let image_width = 7; + let in_channels = 3; + let out_channels = 2; + + let mut image = + Tensor::from((0..in_channels * image_height * image_width).map(|_| F::random(OsRng))); + image + .reshape(&[1, in_channels, image_height, image_width]) + .unwrap(); + image.set_visibility(&crate::graph::Visibility::Private); + + let mut kernels = Tensor::from( + (0..{ out_channels * in_channels * kernel_height * kernel_width }) + .map(|_| F::random(OsRng)), + ); + kernels + .reshape(&[out_channels, in_channels, kernel_height, kernel_width]) + .unwrap(); + kernels.set_visibility(&crate::graph::Visibility::Private); + + let mut bias = Tensor::from((0..{ out_channels }).map(|_| F::random(OsRng))); + bias.set_visibility(&crate::graph::Visibility::Private); + + let circuit = ConvCircuit:: { + inputs: [image, kernels, bias].to_vec(), + _marker: PhantomData, + }; + + let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap(); + 
prover.assert_satisfied_par(); + } + + #[test] + fn convcircuitnobias() { + // parameters + let kernel_height = 2; + let kernel_width = 2; + let image_height = 4; + let image_width = 5; + let in_channels = 3; + let out_channels = 2; + + let mut image = + Tensor::from((0..in_channels * image_height * image_width).map(|i| F::from(i as u64))); + image + .reshape(&[1, in_channels, image_height, image_width]) + .unwrap(); + image.set_visibility(&crate::graph::Visibility::Private); + + let mut kernels = Tensor::from( + (0..{ out_channels * in_channels * kernel_height * kernel_width }) + .map(|i| F::from(i as u64)), + ); + kernels + .reshape(&[out_channels, in_channels, kernel_height, kernel_width]) + .unwrap(); + kernels.set_visibility(&crate::graph::Visibility::Private); + + let circuit = ConvCircuit:: { + inputs: [image, kernels].to_vec(), + _marker: PhantomData, + }; + + let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap(); + prover.assert_satisfied_par(); + } +} + +#[cfg(test)] +#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] +mod conv_col_ultra_overflow { + use halo2_proofs::poly::commitment::ParamsProver; + + use super::*; + + const K: usize = 4; + const LEN: usize = 28; + + #[derive(Clone)] + struct ConvCircuit { + image: ValTensor, + kernel: Tensor, + _marker: PhantomData, + } + + impl Circuit for ConvCircuit { + type Config = BaseConfig; + type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; + + fn without_witnesses(&self) -> Self { + self.clone() + } + + fn configure(cs: &mut ConstraintSystem) -> Self::Config { + let a = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN); + let b = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN); + let output = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN); + Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE) + } + + fn synthesize( + &self, + mut config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + layouter + .assign_region( + || "", + 
|region| { + let mut region = RegionCtx::new(region, 0, 1); + config + .layout( + &mut region, + &[self.image.clone()], + Box::new(PolyOp::Conv { + kernel: self.kernel.clone(), + bias: None, + padding: [(1, 1); 2], + stride: (2, 2), + }), + ) + .map_err(|_| Error::Synthesis) + }, + ) + .unwrap(); + Ok(()) + } + } + + #[test] + #[ignore] + fn conv_circuit() { + // parameters + let kernel_height = 2; + let kernel_width = 2; + let image_height = LEN; + let image_width = LEN; + let in_channels = 3; + let out_channels = 2; + + // get some logs fam + crate::logger::init_logger(); + let mut image = + Tensor::from((0..in_channels * image_height * image_width).map(|i| F::from(i as u64))); + image + .reshape(&[1, in_channels, image_height, image_width]) + .unwrap(); + image.set_visibility(&crate::graph::Visibility::Private); + + let mut kernels = Tensor::from( + (0..{ out_channels * in_channels * kernel_height * kernel_width }) + .map(|i| F::from(i as u64)), + ); + kernels + .reshape(&[out_channels, in_channels, kernel_height, kernel_width]) + .unwrap(); + kernels.set_visibility(&crate::graph::Visibility::Private); + + let circuit = ConvCircuit:: { + image: ValTensor::try_from(image).unwrap(), + kernel: kernels, + _marker: PhantomData, + }; + + let params = crate::pfsys::srs::gen_srs::< + halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<_>, + >(K as u32); + + let pk = crate::pfsys::create_keys::< + halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme, + F, + ConvCircuit, + >(&circuit, ¶ms) + .unwrap(); + + let prover = crate::pfsys::create_proof_circuit_kzg( + circuit.clone(), + ¶ms, + None, + &pk, + crate::pfsys::TranscriptType::EVM, + halo2_proofs::poly::kzg::strategy::SingleStrategy::new(¶ms), + // use safe mode to verify that the proof is correct + CheckMode::SAFE, + None, + ); + + assert!(prover.is_ok()); + + let proof = prover.unwrap(); + + let strategy = + halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params()); + let vk = 
pk.get_vk(); + let result = + crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, vk, strategy); + + assert!(result.is_ok()); + + println!("done."); + } +} + +#[cfg(test)] +// not wasm 32 unknown +#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] +mod conv_relu_col_ultra_overflow { + use halo2_proofs::poly::commitment::ParamsProver; + + use super::*; + + const K: usize = 4; + const LEN: usize = 28; + + #[derive(Clone)] + struct ConvCircuit { + image: ValTensor, + kernel: Tensor, + _marker: PhantomData, + } + + impl Circuit for ConvCircuit { + type Config = BaseConfig; + type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; + + fn without_witnesses(&self) -> Self { + self.clone() + } + + fn configure(cs: &mut ConstraintSystem) -> Self::Config { + let a = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN); + let b = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN); + let output = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN); + let mut base_config = + Self::Config::configure(cs, &[a.clone(), b.clone()], &output, CheckMode::SAFE); + // sets up a new relu table + base_config + .configure_lookup(cs, &b, &output, &a, (-3, 3), K, &LookupOp::ReLU) + .unwrap(); + base_config.clone() + } + + fn synthesize( + &self, + mut config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + config.layout_tables(&mut layouter).unwrap(); + layouter + .assign_region( + || "", + |region| { + let mut region = RegionCtx::new(region, 0, 1); + let output = config + .layout( + &mut region, + &[self.image.clone()], + Box::new(PolyOp::Conv { + kernel: self.kernel.clone(), + bias: None, + padding: [(1, 1); 2], + stride: (2, 2), + }), + ) + .map_err(|_| Error::Synthesis); + let _output = config + .layout( + &mut region, + &[output.unwrap().unwrap()], + Box::new(LookupOp::ReLU), + ) + .unwrap(); + Ok(()) + }, + ) + .unwrap(); + + Ok(()) + } + } + + #[test] + #[ignore] + fn conv_relu_circuit() { + // parameters + let kernel_height 
= 2; + let kernel_width = 2; + let image_height = LEN; + let image_width = LEN; + let in_channels = 3; + let out_channels = 2; + + // get some logs fam + crate::logger::init_logger(); + let mut image = + Tensor::from((0..in_channels * image_height * image_width).map(|_| F::from(0))); + image + .reshape(&[1, in_channels, image_height, image_width]) + .unwrap(); + image.set_visibility(&crate::graph::Visibility::Private); + + let mut kernels = Tensor::from( + (0..{ out_channels * in_channels * kernel_height * kernel_width }).map(|_| F::from(0)), + ); + kernels + .reshape(&[out_channels, in_channels, kernel_height, kernel_width]) + .unwrap(); + kernels.set_visibility(&crate::graph::Visibility::Private); + + let circuit = ConvCircuit:: { + image: ValTensor::try_from(image).unwrap(), + kernel: kernels, + _marker: PhantomData, + }; + + let params = crate::pfsys::srs::gen_srs::< + halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<_>, + >(K as u32); + + let pk = crate::pfsys::create_keys::< + halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme, + F, + ConvCircuit, + >(&circuit, ¶ms) + .unwrap(); + + let prover = crate::pfsys::create_proof_circuit_kzg( + circuit.clone(), + ¶ms, + None, + &pk, + crate::pfsys::TranscriptType::EVM, + halo2_proofs::poly::kzg::strategy::SingleStrategy::new(¶ms), + // use safe mode to verify that the proof is correct + CheckMode::SAFE, + None, + ); + + assert!(prover.is_ok()); + + let proof = prover.unwrap(); + + let strategy = + halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params()); + let vk = pk.get_vk(); + let result = + crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, vk, strategy); + + assert!(result.is_ok()); + + println!("done."); + } +} + +#[cfg(test)] +mod sumpool { + + use super::*; + + const K: usize = 20; + const LEN: usize = 100; + + #[derive(Clone)] + struct ConvCircuit { + inputs: Vec>, + _marker: PhantomData, + } + + impl Circuit for ConvCircuit { + type Config = 
BaseConfig; + type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; + + fn without_witnesses(&self) -> Self { + self.clone() + } + + fn configure(cs: &mut ConstraintSystem) -> Self::Config { + let a = VarTensor::new_advice(cs, K, 1, (LEN + 1) * LEN); + let b = VarTensor::new_advice(cs, K, 1, (LEN + 1) * LEN); + let output = VarTensor::new_advice(cs, K, 1, (LEN + 1) * LEN); + VarTensor::constant_cols(cs, K, 2, false); + Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE) + } + + fn synthesize( + &self, + mut config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + layouter + .assign_region( + || "", + |region| { + let mut region = RegionCtx::new(region, 0, 1); + config + .layout( + &mut region, + &self.inputs.clone(), + Box::new(PolyOp::SumPool { + padding: [(0, 0); 2], + stride: (1, 1), + kernel_shape: (3, 3), + }), + ) + .map_err(|_| Error::Synthesis) + }, + ) + .unwrap(); + Ok(()) + } + } + + #[test] + fn sumpoolcircuit() { + let image_height = 5; + let image_width = 5; + let in_channels = 1; + + let mut image = Tensor::from( + (0..in_channels * image_height * image_width).map(|_| Value::known(F::random(OsRng))), + ); + image + .reshape(&[1, in_channels, image_height, image_width]) + .unwrap(); + + let circuit = ConvCircuit:: { + inputs: [ValTensor::from(image)].to_vec(), + _marker: PhantomData, + }; + + let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap(); + prover.assert_satisfied_par(); + } +} + +#[cfg(test)] +mod add_w_shape_casting { + use super::*; + + const K: usize = 4; + const LEN: usize = 4; + + #[derive(Clone)] + struct MyCircuit { + inputs: [ValTensor; 2], + _marker: PhantomData, + } + + impl Circuit for MyCircuit { + type Config = BaseConfig; + type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; + + fn without_witnesses(&self) -> Self { + self.clone() + } + + fn configure(cs: &mut ConstraintSystem) -> Self::Config { + let a = VarTensor::new_advice(cs, K, 1, LEN); + let b = 
VarTensor::new_advice(cs, K, 1, LEN); + let output = VarTensor::new_advice(cs, K, 1, LEN); + + Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE) + } + + fn synthesize( + &self, + mut config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + layouter + .assign_region( + || "", + |region| { + let mut region = RegionCtx::new(region, 0, 1); + config + .layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Add)) + .map_err(|_| Error::Synthesis) + }, + ) + .unwrap(); + Ok(()) + } + } + + #[test] + fn addcircuit() { + // parameters + let a = Tensor::from((0..LEN).map(|i| Value::known(F::from(i as u64 + 1)))); + + let b = Tensor::from((0..1).map(|i| Value::known(F::from(i as u64 + 1)))); + + let circuit = MyCircuit:: { + inputs: [ValTensor::from(a), ValTensor::from(b)], + _marker: PhantomData, + }; + + let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap(); + prover.assert_satisfied_par(); + } +} + +#[cfg(test)] +mod add { + use super::*; + + const K: usize = 4; + const LEN: usize = 4; + + #[derive(Clone)] + struct MyCircuit { + inputs: [ValTensor; 2], + _marker: PhantomData, + } + + impl Circuit for MyCircuit { + type Config = BaseConfig; + type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; + + fn without_witnesses(&self) -> Self { + self.clone() + } + + fn configure(cs: &mut ConstraintSystem) -> Self::Config { + let a = VarTensor::new_advice(cs, K, 1, LEN); + let b = VarTensor::new_advice(cs, K, 1, LEN); + let output = VarTensor::new_advice(cs, K, 1, LEN); + + Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE) + } + + fn synthesize( + &self, + mut config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + layouter + .assign_region( + || "", + |region| { + let mut region = RegionCtx::new(region, 0, 1); + config + .layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Add)) + .map_err(|_| Error::Synthesis) + }, + ) + .unwrap(); + Ok(()) + } + } + + #[test] + fn 
addcircuit() { + // parameters + let a = Tensor::from((0..LEN).map(|i| Value::known(F::from(i as u64 + 1)))); + + let b = Tensor::from((0..LEN).map(|i| Value::known(F::from(i as u64 + 1)))); + + let circuit = MyCircuit:: { + inputs: [ValTensor::from(a), ValTensor::from(b)], + _marker: PhantomData, + }; + + let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap(); + prover.assert_satisfied_par(); + } +} + +#[cfg(test)] +mod add_with_overflow { + use super::*; + + const K: usize = 4; + const LEN: usize = 50; + + #[derive(Clone)] + struct MyCircuit { + inputs: [ValTensor; 2], + _marker: PhantomData, + } + + impl Circuit for MyCircuit { + type Config = BaseConfig; + type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; + + fn without_witnesses(&self) -> Self { + self.clone() + } + + fn configure(cs: &mut ConstraintSystem) -> Self::Config { + let a = VarTensor::new_advice(cs, K, 1, LEN); + let b = VarTensor::new_advice(cs, K, 1, LEN); + let output = VarTensor::new_advice(cs, K, 1, LEN); + + Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE) + } + + fn synthesize( + &self, + mut config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + layouter + .assign_region( + || "", + |region| { + let mut region = RegionCtx::new(region, 0, 1); + config + .layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Add)) + .map_err(|_| Error::Synthesis) + }, + ) + .unwrap(); + Ok(()) + } + } + + #[test] + fn addcircuit() { + // parameters + let a = Tensor::from((0..LEN).map(|i| Value::known(F::from(i as u64 + 1)))); + + let b = Tensor::from((0..LEN).map(|i| Value::known(F::from(i as u64 + 1)))); + + let circuit = MyCircuit:: { + inputs: [ValTensor::from(a), ValTensor::from(b)], + _marker: PhantomData, + }; + + let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap(); + prover.assert_satisfied_par(); + } +} + +#[cfg(test)] +mod add_with_overflow_and_poseidon { + use halo2curves::bn256::Fr; + + use 
crate::circuit::modules::{ + poseidon::{ + spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH}, + PoseidonChip, PoseidonConfig, + }, + Module, ModulePlanner, + }; + + use super::*; + + const K: usize = 15; + const LEN: usize = 50; + const WIDTH: usize = POSEIDON_WIDTH; + const RATE: usize = POSEIDON_RATE; + + #[derive(Debug, Clone)] + struct MyCircuitConfig { + base: BaseConfig, + poseidon: PoseidonConfig, + } + + #[derive(Clone)] + struct MyCircuit { + inputs: [ValTensor; 2], + } + + impl Circuit for MyCircuit { + type Config = MyCircuitConfig; + type FloorPlanner = ModulePlanner; + type Params = TestParams; + + fn without_witnesses(&self) -> Self { + self.clone() + } + + fn configure(cs: &mut ConstraintSystem) -> Self::Config { + let a = VarTensor::new_advice(cs, K, 1, LEN); + let b = VarTensor::new_advice(cs, K, 1, LEN); + let output = VarTensor::new_advice(cs, K, 1, LEN); + + let base = BaseConfig::configure(cs, &[a, b], &output, CheckMode::SAFE); + VarTensor::constant_cols(cs, K, 2, false); + + let poseidon = PoseidonChip::::configure(cs, ()); + + MyCircuitConfig { base, poseidon } + } + + fn synthesize( + &self, + mut config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + let poseidon_chip: PoseidonChip = + PoseidonChip::new(config.poseidon.clone()); + + let assigned_inputs_a = poseidon_chip.layout(&mut layouter, &self.inputs[0..1], 0)?; + let assigned_inputs_b = poseidon_chip.layout(&mut layouter, &self.inputs[1..2], 1)?; + + layouter.assign_region(|| "_new_module", |_| Ok(()))?; + + let inputs = vec![assigned_inputs_a, assigned_inputs_b]; + + layouter.assign_region( + || "model", + |region| { + let mut region = RegionCtx::new(region, 0, 1); + config + .base + .layout(&mut region, &inputs, Box::new(PolyOp::Add)) + .map_err(|_| Error::Synthesis) + }, + )?; + + Ok(()) + } + } + + #[test] + fn addcircuit() { + let a = (0..LEN) + .map(|i| halo2curves::bn256::Fr::from(i as u64 + 1)) + .collect::>(); + let b = (0..LEN) + .map(|i| 
halo2curves::bn256::Fr::from(i as u64 + 1)) + .collect::>(); + let commitment_a = + PoseidonChip::::run(a.clone()).unwrap()[0][0]; + + let commitment_b = + PoseidonChip::::run(b.clone()).unwrap()[0][0]; + + // parameters + let a = Tensor::from(a.into_iter().map(Value::known)); + let b = Tensor::from(b.into_iter().map(Value::known)); + let circuit = MyCircuit { + inputs: [ValTensor::from(a), ValTensor::from(b)], + }; + + let prover = + MockProver::run(K as u32, &circuit, vec![vec![commitment_a, commitment_b]]).unwrap(); + prover.assert_satisfied_par(); + } + + #[test] + fn addcircuit_bad_hashes() { + let a = (0..LEN) + .map(|i| halo2curves::bn256::Fr::from(i as u64 + 1)) + .collect::>(); + let b = (0..LEN) + .map(|i| halo2curves::bn256::Fr::from(i as u64 + 1)) + .collect::>(); + let commitment_a = PoseidonChip::::run(a.clone()) + .unwrap()[0][0] + + Fr::one(); + + let commitment_b = PoseidonChip::::run(b.clone()) + .unwrap()[0][0] + + Fr::one(); + + // parameters + let a = Tensor::from(a.into_iter().map(Value::known)); + let b = Tensor::from(b.into_iter().map(Value::known)); + let circuit = MyCircuit { + inputs: [ValTensor::from(a), ValTensor::from(b)], + }; + + let prover = + MockProver::run(K as u32, &circuit, vec![vec![commitment_a, commitment_b]]).unwrap(); + assert!(prover.verify().is_err()); + } +} + +#[cfg(test)] +mod sub { + use super::*; + + const K: usize = 4; + const LEN: usize = 4; + + #[derive(Clone)] + struct MyCircuit { + inputs: [ValTensor; 2], + _marker: PhantomData, + } + + impl Circuit for MyCircuit { + type Config = BaseConfig; + type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; + + fn without_witnesses(&self) -> Self { + self.clone() + } + + fn configure(cs: &mut ConstraintSystem) -> Self::Config { + let a = VarTensor::new_advice(cs, K, 1, LEN); + let b = VarTensor::new_advice(cs, K, 1, LEN); + let output = VarTensor::new_advice(cs, K, 1, LEN); + + Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE) + } + + fn 
synthesize( + &self, + mut config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + layouter + .assign_region( + || "", + |region| { + let mut region = RegionCtx::new(region, 0, 1); + config + .layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Sub)) + .map_err(|_| Error::Synthesis) + }, + ) + .unwrap(); + Ok(()) + } + } + + #[test] + fn subcircuit() { + // parameters + let a = Tensor::from((0..LEN).map(|i| Value::known(F::from(i as u64 + 1)))); + + let b = Tensor::from((0..LEN).map(|i| Value::known(F::from(i as u64 + 1)))); + + let circuit = MyCircuit:: { + inputs: [ValTensor::from(a), ValTensor::from(b)], + _marker: PhantomData, + }; + + let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap(); + prover.assert_satisfied_par(); + } +} + +#[cfg(test)] +mod mult { + use super::*; + + const K: usize = 4; + const LEN: usize = 4; + + #[derive(Clone)] + struct MyCircuit { + inputs: [ValTensor; 2], + _marker: PhantomData, + } + + impl Circuit for MyCircuit { + type Config = BaseConfig; + type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; + + fn without_witnesses(&self) -> Self { + self.clone() + } + + fn configure(cs: &mut ConstraintSystem) -> Self::Config { + let a = VarTensor::new_advice(cs, K, 1, LEN); + let b = VarTensor::new_advice(cs, K, 1, LEN); + let output = VarTensor::new_advice(cs, K, 1, LEN); + + Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE) + } + + fn synthesize( + &self, + mut config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + layouter + .assign_region( + || "", + |region| { + let mut region = RegionCtx::new(region, 0, 1); + config + .layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Mult)) + .map_err(|_| Error::Synthesis) + }, + ) + .unwrap(); + Ok(()) + } + } + + #[test] + fn multcircuit() { + // parameters + let a = Tensor::from((0..LEN).map(|i| Value::known(F::from(i as u64 + 1)))); + + let b = Tensor::from((0..LEN).map(|i| 
Value::known(F::from(i as u64 + 1)))); + + let circuit = MyCircuit:: { + inputs: [ValTensor::from(a), ValTensor::from(b)], + _marker: PhantomData, + }; + + let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap(); + prover.assert_satisfied_par(); + } +} + +#[cfg(test)] +mod pow { + use super::*; + + const K: usize = 8; + const LEN: usize = 4; + + #[derive(Clone)] + struct MyCircuit { + inputs: [ValTensor; 1], + _marker: PhantomData, + } + + impl Circuit for MyCircuit { + type Config = BaseConfig; + type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; + + fn without_witnesses(&self) -> Self { + self.clone() + } + + fn configure(cs: &mut ConstraintSystem) -> Self::Config { + let a = VarTensor::new_advice(cs, K, 1, LEN); + let b = VarTensor::new_advice(cs, K, 1, LEN); + let output = VarTensor::new_advice(cs, K, 1, LEN); + + Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE) + } + + fn synthesize( + &self, + mut config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + layouter + .assign_region( + || "", + |region| { + let mut region = RegionCtx::new(region, 0, 1); + config + .layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Pow(5))) + .map_err(|_| Error::Synthesis) + }, + ) + .unwrap(); + Ok(()) + } + } + + #[test] + fn powcircuit() { + // parameters + let a = Tensor::from((0..LEN).map(|i| Value::known(F::from(i as u64 + 1)))); + + let circuit = MyCircuit:: { + inputs: [ValTensor::from(a)], + _marker: PhantomData, + }; + + let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap(); + prover.assert_satisfied_par(); + } +} + +#[cfg(test)] +mod pack { + use super::*; + + const K: usize = 8; + const LEN: usize = 4; + + #[derive(Clone)] + struct MyCircuit { + inputs: [ValTensor; 1], + _marker: PhantomData, + } + + impl Circuit for MyCircuit { + type Config = BaseConfig; + type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; + + fn without_witnesses(&self) -> Self { + self.clone() + 
} + + fn configure(cs: &mut ConstraintSystem) -> Self::Config { + let a = VarTensor::new_advice(cs, K, 1, LEN); + let b = VarTensor::new_advice(cs, K, 1, LEN); + let output = VarTensor::new_advice(cs, K, 1, LEN); + + Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE) + } + + fn synthesize( + &self, + mut config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + layouter + .assign_region( + || "", + |region| { + let mut region = RegionCtx::new(region, 0, 1); + config + .layout( + &mut region, + &self.inputs.clone(), + Box::new(PolyOp::Pack(2, 1)), + ) + .map_err(|_| Error::Synthesis) + }, + ) + .unwrap(); + Ok(()) + } + } + + #[test] + fn packcircuit() { + // parameters + let a = Tensor::from((0..LEN).map(|i| Value::known(F::from(i as u64 + 1)))); + + let circuit = MyCircuit:: { + inputs: [ValTensor::from(a)], + _marker: PhantomData, + }; + + let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap(); + prover.assert_satisfied_par(); + } +} + +#[cfg(test)] +mod matmul_relu { + use super::*; + + const K: usize = 18; + const LEN: usize = 32; + use crate::circuit::lookup::LookupOp; + + #[derive(Clone)] + struct MyCircuit { + inputs: [ValTensor; 2], + _marker: PhantomData, + } + + // A columnar ReLu MLP + #[derive(Clone)] + struct MyConfig { + base_config: BaseConfig, + } + + impl Circuit for MyCircuit { + type Config = MyConfig; + type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; + + fn without_witnesses(&self) -> Self { + self.clone() + } + + fn configure(cs: &mut ConstraintSystem) -> Self::Config { + let a = VarTensor::new_advice(cs, K, 1, LEN); + let b = VarTensor::new_advice(cs, K, 1, LEN); + let output = VarTensor::new_advice(cs, K, 1, LEN); + + let mut base_config = + BaseConfig::configure(cs, &[a.clone(), b.clone()], &output, CheckMode::SAFE); + // sets up a new relu table + base_config + .configure_lookup(cs, &b, &output, &a, (-32768, 32768), K, &LookupOp::ReLU) + .unwrap(); + + MyConfig { base_config 
} + } + + fn synthesize( + &self, + mut config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + config.base_config.layout_tables(&mut layouter).unwrap(); + layouter.assign_region( + || "", + |region| { + let mut region = RegionCtx::new(region, 0, 1); + let op = PolyOp::Einsum { + equation: "ij,jk->ik".to_string(), + }; + let output = config + .base_config + .layout(&mut region, &self.inputs, Box::new(op)) + .unwrap(); + let _output = config + .base_config + .layout(&mut region, &[output.unwrap()], Box::new(LookupOp::ReLU)) + .unwrap(); + Ok(()) + }, + )?; + + Ok(()) + } + } + + #[test] + fn matmulrelucircuit() { + // parameters + let mut a = Tensor::from((0..LEN * LEN).map(|_| Value::known(F::from(1)))); + a.reshape(&[LEN, LEN]).unwrap(); + + // parameters + let mut b = Tensor::from((0..LEN).map(|_| Value::known(F::from(1)))); + b.reshape(&[LEN, 1]).unwrap(); + + let circuit = MyCircuit { + inputs: [ValTensor::from(a), ValTensor::from(b)], + _marker: PhantomData, + }; + + let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap(); + prover.assert_satisfied_par(); + } +} + +#[cfg(test)] +mod rangecheckpercent { + use crate::circuit::Tolerance; + use crate::{circuit, tensor::Tensor}; + use halo2_proofs::{ + circuit::{Layouter, SimpleFloorPlanner, Value}, + dev::MockProver, + plonk::{Circuit, ConstraintSystem, Error}, + }; + + const RANGE: f32 = 1.0; // 1 percent error tolerance + const K: usize = 18; + const LEN: usize = 1; + const SCALE: usize = i128::pow(2, 7) as usize; + + use super::*; + + #[derive(Clone)] + struct MyCircuit { + input: ValTensor, + output: ValTensor, + _marker: PhantomData, + } + + impl Circuit for MyCircuit { + type Config = BaseConfig; + type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; + + fn without_witnesses(&self) -> Self { + self.clone() + } + + fn configure(cs: &mut ConstraintSystem) -> Self::Config { + let scale = utils::F32(SCALE.pow(2) as f32); + let a = VarTensor::new_advice(cs, K, 1, 
LEN); + let b = VarTensor::new_advice(cs, K, 1, LEN); + let output = VarTensor::new_advice(cs, K, 1, LEN); + let mut config = + Self::Config::configure(cs, &[a.clone(), b.clone()], &output, CheckMode::SAFE); + // set up a new GreaterThan and Recip tables + let nl = &LookupOp::GreaterThan { + a: circuit::utils::F32((RANGE * scale.0) / 100.0), + }; + config + .configure_lookup(cs, &b, &output, &a, (-32768, 32768), K, nl) + .unwrap(); + config + .configure_lookup( + cs, + &b, + &output, + &a, + (-32768, 32768), + K, + &LookupOp::Recip { scale }, + ) + .unwrap(); + config + } + + fn synthesize( + &self, + mut config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + config.layout_tables(&mut layouter).unwrap(); + layouter + .assign_region( + || "", + |region| { + let mut region = RegionCtx::new(region, 0, 1); + config + .layout( + &mut region, + &[self.output.clone(), self.input.clone()], + Box::new(HybridOp::RangeCheck(Tolerance { + val: RANGE, + scale: SCALE.into(), + })), + ) + .map_err(|_| Error::Synthesis) + }, + ) + .unwrap(); + Ok(()) + } + } + + #[test] + #[allow(clippy::assertions_on_constants)] + fn test_range_check_percent() { + // Successful cases + { + let inp = Tensor::new(Some(&[Value::::known(F::from(100_u64))]), &[1]).unwrap(); + let out = Tensor::new(Some(&[Value::::known(F::from(101_u64))]), &[1]).unwrap(); + let circuit = MyCircuit:: { + input: ValTensor::from(inp), + output: ValTensor::from(out), + _marker: PhantomData, + }; + let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap(); + prover.assert_satisfied_par(); + } + { + let inp = Tensor::new(Some(&[Value::::known(F::from(200_u64))]), &[1]).unwrap(); + let out = Tensor::new(Some(&[Value::::known(F::from(199_u64))]), &[1]).unwrap(); + let circuit = MyCircuit:: { + input: ValTensor::from(inp), + output: ValTensor::from(out), + _marker: PhantomData, + }; + let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap(); + prover.assert_satisfied_par(); + } + + 
// Unsuccessful case + { + let inp = Tensor::new(Some(&[Value::::known(F::from(100_u64))]), &[1]).unwrap(); + let out = Tensor::new(Some(&[Value::::known(F::from(102_u64))]), &[1]).unwrap(); + let circuit = MyCircuit:: { + input: ValTensor::from(inp), + output: ValTensor::from(out), + _marker: PhantomData, + }; + let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap(); + match prover.verify() { + Ok(_) => { + assert!(false) + } + Err(_) => { + assert!(true) + } + } + } + } +} + +#[cfg(test)] +mod relu { + use super::*; + use halo2_proofs::{ + circuit::{Layouter, SimpleFloorPlanner, Value}, + dev::MockProver, + plonk::{Circuit, ConstraintSystem, Error}, + }; + + #[derive(Clone)] + struct ReLUCircuit { + pub input: ValTensor, + } + + impl Circuit for ReLUCircuit { + type Config = BaseConfig; + type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; + + fn without_witnesses(&self) -> Self { + self.clone() + } + + fn configure(cs: &mut ConstraintSystem) -> Self::Config { + let advices = (0..3) + .map(|_| VarTensor::new_advice(cs, 4, 1, 3)) + .collect::>(); + + let nl = LookupOp::ReLU; + + let mut config = BaseConfig::default(); + + config + .configure_lookup(cs, &advices[0], &advices[1], &advices[2], (-6, 6), 4, &nl) + .unwrap(); + config + } + + fn synthesize( + &self, + mut config: Self::Config, + mut layouter: impl Layouter, // layouter is our 'write buffer' for the circuit + ) -> Result<(), Error> { + config.layout_tables(&mut layouter).unwrap(); + layouter + .assign_region( + || "", + |region| { + let mut region = RegionCtx::new(region, 0, 1); + config + .layout(&mut region, &[self.input.clone()], Box::new(LookupOp::ReLU)) + .map_err(|_| Error::Synthesis) + }, + ) + .unwrap(); + + Ok(()) + } + } + + #[test] + fn relucircuit() { + let input: Tensor> = + Tensor::new(Some(&[Value::::known(F::from(1_u64)); 4]), &[4]).unwrap(); + + let circuit = ReLUCircuit:: { + input: ValTensor::from(input), + }; + + let prover = MockProver::run(4_u32, &circuit, 
vec![]).unwrap(); + prover.assert_satisfied_par(); + } +} + +#[cfg(test)] +#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] +mod lookup_ultra_overflow { + use super::*; + use halo2_proofs::{ + circuit::{Layouter, SimpleFloorPlanner, Value}, + plonk::{Circuit, ConstraintSystem, Error}, + poly::commitment::ParamsProver, + }; + + #[derive(Clone)] + struct ReLUCircuit { + pub input: ValTensor, + } + + impl Circuit for ReLUCircuit { + type Config = BaseConfig; + type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; + + fn without_witnesses(&self) -> Self { + self.clone() + } + + fn configure(cs: &mut ConstraintSystem) -> Self::Config { + let advices = (0..3) + .map(|_| VarTensor::new_advice(cs, 4, 1, 3)) + .collect::>(); + + let nl = LookupOp::ReLU; + + let mut config = BaseConfig::default(); + + config + .configure_lookup( + cs, + &advices[0], + &advices[1], + &advices[2], + (-1024, 1024), + 4, + &nl, + ) + .unwrap(); + config + } + + fn synthesize( + &self, + mut config: Self::Config, + mut layouter: impl Layouter, // layouter is our 'write buffer' for the circuit + ) -> Result<(), Error> { + config.layout_tables(&mut layouter).unwrap(); + layouter + .assign_region( + || "", + |region| { + let mut region = RegionCtx::new(region, 0, 1); + config + .layout(&mut region, &[self.input.clone()], Box::new(LookupOp::ReLU)) + .map_err(|_| Error::Synthesis) + }, + ) + .unwrap(); + + Ok(()) + } + } + + #[test] + #[ignore] + fn relucircuit() { + // get some logs fam + crate::logger::init_logger(); + // parameters + let a = Tensor::from((0..4).map(|i| Value::known(F::from(i + 1)))); + + let circuit = ReLUCircuit:: { + input: ValTensor::from(a), + }; + + let params = crate::pfsys::srs::gen_srs::< + halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<_>, + >(4_u32); + + let pk = crate::pfsys::create_keys::< + halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme, + F, + ReLUCircuit, + >(&circuit, ¶ms) + .unwrap(); + + let prover = 
crate::pfsys::create_proof_circuit_kzg( + circuit.clone(), + ¶ms, + None, + &pk, + crate::pfsys::TranscriptType::EVM, + halo2_proofs::poly::kzg::strategy::SingleStrategy::new(¶ms), + // use safe mode to verify that the proof is correct + CheckMode::SAFE, + None, + ); + + assert!(prover.is_ok()); + + let proof = prover.unwrap(); + + let strategy = + halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params()); + let vk = pk.get_vk(); + let result = + crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, vk, strategy); + + assert!(result.is_ok()); + + println!("done."); + } +} + +#[cfg(test)] +mod softmax { + + use super::*; + use halo2_proofs::{ + circuit::{Layouter, SimpleFloorPlanner, Value}, + dev::MockProver, + plonk::{Circuit, ConstraintSystem, Error}, + }; + + const K: usize = 18; + const LEN: usize = 3; + const SCALE: f32 = 128.0; + + #[derive(Clone)] + struct SoftmaxCircuit { + pub input: ValTensor, + _marker: PhantomData, + } + + impl Circuit for SoftmaxCircuit { + type Config = BaseConfig; + type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; + + fn without_witnesses(&self) -> Self { + self.clone() + } + fn configure(cs: &mut ConstraintSystem) -> Self::Config { + let a = VarTensor::new_advice(cs, K, 1, LEN); + let b = VarTensor::new_advice(cs, K, 1, LEN); + let output = VarTensor::new_advice(cs, K, 1, LEN); + let mut config = Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE); + let advices = (0..3) + .map(|_| VarTensor::new_advice(cs, K, 1, LEN)) + .collect::>(); + + config + .configure_lookup( + cs, + &advices[0], + &advices[1], + &advices[2], + (-32768, 32768), + K, + &LookupOp::Exp { + scale: SCALE.into(), + }, + ) + .unwrap(); + config + .configure_lookup( + cs, + &advices[0], + &advices[1], + &advices[2], + (-32768, 32768), + K, + &LookupOp::Recip { + scale: SCALE.powf(2.0).into(), + }, + ) + .unwrap(); + config + } + + fn synthesize( + &self, + mut config: Self::Config, + mut 
layouter: impl Layouter, + ) -> Result<(), Error> { + config.layout_tables(&mut layouter).unwrap(); + layouter + .assign_region( + || "", + |region| { + let mut region = RegionCtx::new(region, 0, 1); + let _output = config + .layout( + &mut region, + &[self.input.clone()], + Box::new(HybridOp::Softmax { + scale: SCALE.into(), + axes: vec![0], + }), + ) + .unwrap(); + Ok(()) + }, + ) + .unwrap(); + + Ok(()) + } + } + + #[test] + fn softmax_circuit() { + let input = Tensor::from((0..LEN).map(|i| Value::known(F::from(i as u64 + 1)))); + + let circuit = SoftmaxCircuit:: { + input: ValTensor::from(input), + _marker: PhantomData, + }; + let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap(); + prover.assert_satisfied_par(); + } +} diff --git a/mnist_ezkl/src/circuit/utils.rs b/mnist_ezkl/src/circuit/utils.rs new file mode 100644 index 0000000..c55ef4f --- /dev/null +++ b/mnist_ezkl/src/circuit/utils.rs @@ -0,0 +1,163 @@ +use serde::{Deserialize, Serialize}; + +// -------------------------------------------------------------------------------------------- +// +// Float Utils to enable the usage of f32s as the keys of HashMaps +// This section is taken from the `eq_float` crate verbatim -- but we also implement deserialization methods +// +// + +use std::cmp::Ordering; +use std::fmt; +use std::hash::{Hash, Hasher}; + +#[derive(Debug, Default, Clone, Copy)] +/// f32 wrapper +pub struct F32(pub f32); + +impl<'de> Deserialize<'de> for F32 { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let float = f32::deserialize(deserializer)?; + Ok(F32(float)) + } +} + +impl Serialize for F32 { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + f32::serialize(&self.0, serializer) + } +} + +/// This works like `PartialEq` on `f32`, except that `NAN == NAN` is true. 
+impl PartialEq for F32 { + fn eq(&self, other: &Self) -> bool { + if self.0.is_nan() && other.0.is_nan() { + true + } else { + self.0 == other.0 + } + } +} + +impl Eq for F32 {} + +/// This works like `PartialOrd` on `f32`, except that `NAN` sorts below all other floats +/// (and is equal to another NAN). This always returns a `Some`. +impl PartialOrd for F32 { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +/// This works like `PartialOrd` on `f32`, except that `NAN` sorts below all other floats +/// (and is equal to another NAN). +impl Ord for F32 { + fn cmp(&self, other: &Self) -> Ordering { + self.0.partial_cmp(&other.0).unwrap_or_else(|| { + if self.0.is_nan() && !other.0.is_nan() { + Ordering::Less + } else if !self.0.is_nan() && other.0.is_nan() { + Ordering::Greater + } else { + Ordering::Equal + } + }) + } +} + +impl Hash for F32 { + fn hash(&self, state: &mut H) { + if self.0.is_nan() { + 0x7fc00000u32.hash(state); // a particular bit representation for NAN + } else if self.0 == 0.0 { + // catches both positive and negative zero + 0u32.hash(state); + } else { + self.0.to_bits().hash(state); + } + } +} + +impl From for f32 { + fn from(f: F32) -> Self { + f.0 + } +} + +impl From for F32 { + fn from(f: f32) -> Self { + F32(f) + } +} + +impl From for F32 { + fn from(f: f64) -> Self { + F32(f as f32) + } +} + +impl From for F32 { + fn from(f: usize) -> Self { + F32(f as f32) + } +} + +impl From for f64 { + fn from(f: F32) -> Self { + f.0 as f64 + } +} + +impl From<&F32> for f64 { + fn from(f: &F32) -> Self { + f.0 as f64 + } +} + +impl fmt::Display for F32 { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.0.fmt(f) + } +} + +#[cfg(test)] +mod tests { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + + use super::F32; + + fn calculate_hash(t: &T) -> u64 { + let mut s = DefaultHasher::new(); + t.hash(&mut s); + s.finish() + } + + #[test] + fn f32_eq() { + 
assert!(F32(std::f32::NAN) == F32(std::f32::NAN)); + assert!(F32(std::f32::NAN) != F32(5.0)); + assert!(F32(5.0) != F32(std::f32::NAN)); + assert!(F32(0.0) == F32(-0.0)); + } + + #[test] + fn f32_cmp() { + assert!(F32(std::f32::NAN) == F32(std::f32::NAN)); + assert!(F32(std::f32::NAN) < F32(5.0)); + assert!(F32(5.0) > F32(std::f32::NAN)); + assert!(F32(0.0) == F32(-0.0)); + } + + #[test] + fn f32_hash() { + assert!(calculate_hash(&F32(0.0)) == calculate_hash(&F32(-0.0))); + assert!(calculate_hash(&F32(std::f32::NAN)) == calculate_hash(&F32(-std::f32::NAN))); + } +} diff --git a/mnist_ezkl/src/commands.rs b/mnist_ezkl/src/commands.rs new file mode 100644 index 0000000..157f895 --- /dev/null +++ b/mnist_ezkl/src/commands.rs @@ -0,0 +1,767 @@ +use clap::{Parser, Subcommand, ValueEnum}; +#[cfg(not(target_arch = "wasm32"))] +use ethers::types::H160; +#[cfg(feature = "python-bindings")] +use pyo3::{ + conversion::{FromPyObject, PyTryFrom}, + exceptions::PyValueError, + prelude::*, + types::PyString, +}; +use serde::{Deserialize, Serialize}; +use std::error::Error; +use std::path::PathBuf; + +use crate::{pfsys::ProofType, RunArgs}; + +use crate::circuit::CheckMode; +#[cfg(not(target_arch = "wasm32"))] +use crate::graph::TestDataSource; +use crate::pfsys::TranscriptType; + +impl std::fmt::Display for TranscriptType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.to_possible_value() + .expect("no values are skipped") + .get_name() + .fmt(f) + } +} +#[cfg(feature = "python-bindings")] +/// Converts TranscriptType into a PyObject (Required for TranscriptType to be compatible with Python) +impl IntoPy for TranscriptType { + fn into_py(self, py: Python) -> PyObject { + match self { + TranscriptType::Poseidon => "poseidon".to_object(py), + TranscriptType::EVM => "evm".to_object(py), + } + } +} +#[cfg(feature = "python-bindings")] +/// Obtains TranscriptType from PyObject (Required for TranscriptType to be compatible with Python) +impl<'source> 
FromPyObject<'source> for TranscriptType { + fn extract(ob: &'source PyAny) -> PyResult { + let trystr = ::try_from(ob)?; + let strval = trystr.to_string(); + match strval.to_lowercase().as_str() { + "poseidon" => Ok(TranscriptType::Poseidon), + "evm" => Ok(TranscriptType::EVM), + _ => Err(PyValueError::new_err("Invalid value for TranscriptType")), + } + } +} + +#[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq, PartialOrd)] +/// Determines what the calibration pass should optimize for +pub enum CalibrationTarget { + /// Optimizes for reducing cpu and memory usage + Resources { + /// Whether to allow for column overflow. This can reduce memory usage (eg. for a browser environment), but may result in a verifier that doesn't fit on the blockchain. + col_overflow: bool, + }, + /// Optimizes for numerical accuracy given the fixed point representation + Accuracy, +} + +impl Default for CalibrationTarget { + fn default() -> Self { + CalibrationTarget::Resources { + col_overflow: false, + } + } +} + +impl ToString for CalibrationTarget { + fn to_string(&self) -> String { + match self { + CalibrationTarget::Resources { col_overflow: true } => { + "resources/col-overflow".to_string() + } + CalibrationTarget::Resources { + col_overflow: false, + } => "resources".to_string(), + CalibrationTarget::Accuracy => "accuracy".to_string(), + } + } +} + +impl From<&str> for CalibrationTarget { + fn from(s: &str) -> Self { + match s { + "resources" => CalibrationTarget::Resources { + col_overflow: false, + }, + "resources/col-overflow" => CalibrationTarget::Resources { col_overflow: true }, + "accuracy" => CalibrationTarget::Accuracy, + _ => { + log::error!("Invalid value for CalibrationTarget"); + log::warn!("Defaulting to resources"); + CalibrationTarget::default() + } + } + } +} + +#[cfg(feature = "python-bindings")] +/// Converts CalibrationTarget into a PyObject (Required for CalibrationTarget to be compatible with Python) +impl IntoPy for CalibrationTarget { + fn 
into_py(self, py: Python) -> PyObject { + match self { + CalibrationTarget::Resources { col_overflow: true } => { + "resources/col-overflow".to_object(py) + } + CalibrationTarget::Resources { + col_overflow: false, + } => "resources".to_object(py), + CalibrationTarget::Accuracy => "accuracy".to_object(py), + } + } +} + +#[cfg(feature = "python-bindings")] +/// Obtains CalibrationTarget from PyObject (Required for CalibrationTarget to be compatible with Python) +impl<'source> FromPyObject<'source> for CalibrationTarget { + fn extract(ob: &'source PyAny) -> PyResult { + let trystr = ::try_from(ob)?; + let strval = trystr.to_string(); + match strval.to_lowercase().as_str() { + "resources" => Ok(CalibrationTarget::Resources { + col_overflow: false, + }), + "resources/col-overflow" => Ok(CalibrationTarget::Resources { col_overflow: true }), + "accuracy" => Ok(CalibrationTarget::Accuracy), + _ => Err(PyValueError::new_err("Invalid value for CalibrationTarget")), + } + } +} + +use lazy_static::lazy_static; + +// if CARGO VERSION is 0.0.0 replace with "source - no compatibility guaranteed" +lazy_static! 
{ + /// The version of the ezkl library + pub static ref VERSION: &'static str = if env!("CARGO_PKG_VERSION") == "0.0.0" { + "source - no compatibility guaranteed" + } else { + env!("CARGO_PKG_VERSION") + }; +} + +#[allow(missing_docs)] +#[derive(Parser, Debug, Clone, Deserialize, Serialize)] +#[command(author, about, long_about = None)] +#[clap(version = *VERSION)] +pub struct Cli { + #[command(subcommand)] + #[allow(missing_docs)] + pub command: Commands, +} + +impl Cli { + /// Export the ezkl configuration as json + pub fn as_json(&self) -> Result> { + let serialized = match serde_json::to_string(&self) { + Ok(s) => s, + Err(e) => { + return Err(Box::new(e)); + } + }; + Ok(serialized) + } + /// Parse an ezkl configuration from a json + pub fn from_json(arg_json: &str) -> Result { + serde_json::from_str(arg_json) + } +} + +#[allow(missing_docs)] +#[derive(Debug, Subcommand, Clone, Deserialize, Serialize, PartialEq, PartialOrd)] +pub enum Commands { + Empty, + /// Loads model and prints model table + #[command(arg_required_else_help = true)] + Table { + /// The path to the .onnx model file + #[arg(short = 'M', long)] + model: PathBuf, + /// proving arguments + #[clap(flatten)] + args: RunArgs, + }, + + #[cfg(feature = "render")] + /// Renders the model circuit to a .png file. For an overview of how to interpret these plots, see https://zcash.github.io/halo2/user/dev-tools.html + #[command(arg_required_else_help = true)] + RenderCircuit { + /// The path to the .onnx model file + #[arg(short = 'M', long)] + model: PathBuf, + /// Path to save the .png circuit render + #[arg(short = 'O', long)] + output: PathBuf, + /// proving arguments + #[clap(flatten)] + args: RunArgs, + }, + + /// Generates the witness from an input file. 
+ #[command(arg_required_else_help = true)] + GenWitness { + /// The path to the .json data file + #[arg(short = 'D', long)] + data: PathBuf, + /// The path to the compiled model file + #[arg(short = 'M', long)] + compiled_circuit: PathBuf, + /// Path to the witness (public and private inputs) .json file + #[arg(short = 'O', long, default_value = "witness.json")] + output: PathBuf, + /// Path to the witness (public and private inputs) .json file (optional - solely used to generate kzg commits) + #[arg(short = 'V', long)] + vk_path: Option, + /// Path to the srs file (optional - solely used to generate kzg commits) + #[arg(short = 'P', long)] + srs_path: Option, + }, + + /// Produces the proving hyperparameters, from run-args + #[command(arg_required_else_help = true)] + GenSettings { + /// The path to the .onnx model file + #[arg(short = 'M', long)] + model: PathBuf, + /// Path to circuit_settings file to output + #[arg(short = 'O', long, default_value = "settings.json")] + settings_path: PathBuf, + /// proving arguments + #[clap(flatten)] + args: RunArgs, + }, + + /// Calibrates the proving scale, lookup bits and logrows from a circuit settings file. + #[cfg(not(target_arch = "wasm32"))] + #[command(arg_required_else_help = true)] + CalibrateSettings { + /// The path to the .onnx model file + #[arg(short = 'M', long)] + model: PathBuf, + /// Path to circuit_settings file to read in AND overwrite. + #[arg(short = 'O', long, default_value = "settings.json")] + settings_path: PathBuf, + /// The path to the .json calibration data file. + #[arg(short = 'D', long = "data")] + data: PathBuf, + #[arg(long = "target", default_value = "resources")] + /// Target for calibration. + target: CalibrationTarget, + /// Optional scales to specifically try for calibration. 
+ #[arg(long, value_delimiter = ',', allow_hyphen_values = true)] + scales: Option>, + /// max logrows to use for calibration, 26 is the max public SRS size + #[arg(long)] + max_logrows: Option, + }, + + /// Generates a dummy SRS + #[command(name = "gen-srs", arg_required_else_help = true)] + GenSrs { + /// The path to output to the desired srs file + #[arg(long, default_value = "kzg.srs")] + srs_path: PathBuf, + /// number of logrows to use for srs + #[arg(long)] + logrows: usize, + }, + + #[cfg(not(target_arch = "wasm32"))] + /// Gets an SRS from a circuit settings file. + #[command(name = "get-srs", arg_required_else_help = true)] + GetSrs { + /// The path to output to the desired srs file + #[arg(long, default_value = "kzg.srs")] + srs_path: PathBuf, + /// Path to circuit_settings file to read in. Overrides logrows if specified. + #[arg(short = 'S', long, default_value = None)] + settings_path: Option, + /// Number of logrows to use for srs. To manually override the logrows, omit specifying the settings_path + #[arg(long, default_value = None)] + logrows: Option, + /// Check mode for srs. verifies downloaded srs is valid. set to unsafe for speed. 
+ #[arg(long, default_value = "safe")] + check: CheckMode, + }, + /// Loads model and input and runs mock prover (for testing) + #[command(arg_required_else_help = true)] + Mock { + /// The path to the .json witness file + #[arg(short = 'W', long)] + witness: PathBuf, + /// The path to the .onnx model file + #[arg(short = 'M', long)] + model: PathBuf, + }, + + /// Mock aggregate proofs + #[command(arg_required_else_help = true)] + MockAggregate { + /// The path to the snarks to aggregate over + #[arg(long)] + aggregation_snarks: Vec, + /// logrows used for aggregation circuit + #[arg(long)] + logrows: u32, + /// whether the accumulated are segments of a larger proof + #[arg(long, default_value = "false")] + split_proofs: bool, + }, + + /// setup aggregation circuit :) + #[command(arg_required_else_help = true)] + SetupAggregate { + /// The path to samples of snarks that will be aggregated over + #[arg(long)] + sample_snarks: Vec, + /// The path to save the desired verification key file + #[arg(long, default_value = "vk_aggr.key")] + vk_path: PathBuf, + /// The path to save the desired proving key file + #[arg(long, default_value = "pk_aggr.key")] + pk_path: PathBuf, + /// The path to SRS + #[arg(long)] + srs_path: PathBuf, + /// logrows used for aggregation circuit + #[arg(long)] + logrows: u32, + /// whether the accumulated are segments of a larger proof + #[arg(long, default_value = "false")] + split_proofs: bool, + }, + /// Aggregates proofs :) + #[command(arg_required_else_help = true)] + Aggregate { + /// The path to the snarks to aggregate over + #[arg(long)] + aggregation_snarks: Vec, + /// The path to load the desired proving key file + #[arg(long)] + pk_path: PathBuf, + /// The path to the desired output file + #[arg(long, default_value = "proof_aggr.proof")] + proof_path: PathBuf, + /// The path to SRS + #[arg(long)] + srs_path: PathBuf, + #[arg( + long, + require_equals = true, + num_args = 0..=1, + default_value_t = TranscriptType::EVM, + value_enum + 
)] + transcript: TranscriptType, + /// logrows used for aggregation circuit + #[arg(long)] + logrows: u32, + /// run sanity checks during calculations (safe or unsafe) + #[arg(long, default_value = "safe")] + check_mode: CheckMode, + /// whether the accumulated are segments of a larger proof + #[arg(long, default_value = "false")] + split_proofs: bool, + }, + /// Compiles a circuit from onnx to a simplified graph (einsum + other ops) and parameters as sets of field elements + #[command(arg_required_else_help = true)] + CompileCircuit { + /// The path to the .onnx model file + #[arg(short = 'M', long)] + model: PathBuf, + /// The path to output the processed model + #[arg(long)] + compiled_circuit: PathBuf, + /// The path to load circuit params from + #[arg(short = 'S', long)] + settings_path: PathBuf, + }, + /// Creates pk and vk + #[command(arg_required_else_help = true)] + Setup { + /// The path to the compiled model file + #[arg(short = 'M', long)] + compiled_circuit: PathBuf, + /// The srs path + #[arg(long)] + srs_path: PathBuf, + /// The path to output the verification key file + #[arg(long, default_value = "vk.key")] + vk_path: PathBuf, + /// The path to output the proving key file + #[arg(long, default_value = "pk.key")] + pk_path: PathBuf, + /// The graph witness (optional - used to override fixed values in the circuit) + #[arg(short = 'W', long)] + witness: Option, + }, + + #[cfg(not(target_arch = "wasm32"))] + /// Fuzzes the proof pipeline with random inputs, random parameters, and random keys + #[command(arg_required_else_help = true)] + Fuzz { + /// The path to the .json witness file, which should include both the network input (possibly private) and the network output (public input to the proof) + #[arg(short = 'W', long)] + witness: PathBuf, + /// The path to the processed model file + #[arg(short = 'M', long)] + compiled_circuit: PathBuf, + #[arg( + long, + require_equals = true, + num_args = 0..=1, + default_value_t = TranscriptType::EVM, + 
value_enum + )] + transcript: TranscriptType, + /// number of fuzz iterations + #[arg(long, default_value = "10")] + num_runs: usize, + }, + #[cfg(not(target_arch = "wasm32"))] + SetupTestEVMData { + /// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof) + #[arg(short = 'D', long)] + data: PathBuf, + /// The path to the compiled model file + #[arg(short = 'M', long)] + compiled_circuit: PathBuf, + /// For testing purposes only. The optional path to the .json data file that will be generated that contains the OnChain data storage information + /// derived from the file information in the data .json file. + /// Should include both the network input (possibly private) and the network output (public input to the proof) + #[arg(short = 'T', long)] + test_data: PathBuf, + /// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state + #[arg(short = 'U', long)] + rpc_url: Option, + /// where does the input data come from + #[arg(long, default_value = "on-chain")] + input_source: TestDataSource, + /// where does the output data come from + #[arg(long, default_value = "on-chain")] + output_source: TestDataSource, + }, + #[cfg(not(target_arch = "wasm32"))] + TestUpdateAccountCalls { + /// The path to verfier contract's address + #[arg(long)] + addr: H160, + /// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof) + #[arg(short = 'D', long)] + data: PathBuf, + /// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state + #[arg(short = 'U', long)] + rpc_url: Option, + }, + #[cfg(not(target_arch = "wasm32"))] + /// Swaps the positions in the transcript that correspond to commitments + #[command(arg_required_else_help = true)] + SwapProofCommitments { + /// The path to the proof file + #[arg(short = 'P', long)] + proof_path: PathBuf, + /// The path to 
the witness file + #[arg(short = 'W', long)] + witness_path: PathBuf, + }, + + #[cfg(not(target_arch = "wasm32"))] + /// Loads model, data, and creates proof + #[command(arg_required_else_help = true)] + Prove { + /// The path to the .json witness file, which should include both the network input (possibly private) and the network output (public input to the proof) + #[arg(short = 'W', long)] + witness: PathBuf, + /// The path to the compiled model file + #[arg(short = 'M', long)] + compiled_circuit: PathBuf, + /// The path to load the desired proving key file + #[arg(long)] + pk_path: PathBuf, + /// The path to the desired output file + #[arg(long, default_value = "proof.proof")] + proof_path: PathBuf, + /// The parameter path + #[arg(long)] + srs_path: PathBuf, + #[arg( + long, + require_equals = true, + num_args = 0..=1, + default_value_t = ProofType::Single, + value_enum + )] + proof_type: ProofType, + /// run sanity checks during calculations (safe or unsafe) + #[arg(long, default_value = "safe")] + check_mode: CheckMode, + }, + #[cfg(not(target_arch = "wasm32"))] + /// Creates an EVM verifier for a single proof + #[command(name = "create-evm-verifier", arg_required_else_help = true)] + CreateEVMVerifier { + /// The path to load the desired params file + #[arg(long)] + srs_path: PathBuf, + /// The path to load circuit settings from + #[arg(short = 'S', long)] + settings_path: PathBuf, + /// The path to load the desired verification key file + #[arg(long)] + vk_path: PathBuf, + /// The path to output the Solidity code + #[arg(long, default_value = "evm_deploy.sol")] + sol_code_path: PathBuf, + /// The path to output the Solidity verifier ABI + #[arg(long, default_value = "verifier_abi.json")] + abi_path: PathBuf, + }, + #[cfg(not(target_arch = "wasm32"))] + /// Creates an EVM verifier that attests to on-chain inputs for a single proof + #[command(name = "create-evm-da", arg_required_else_help = true)] + CreateEVMDataAttestation { + /// The path to load the 
desired srs file from + #[arg(long)] + srs_path: PathBuf, + /// The path to load circuit settings from + #[arg(short = 'S', long)] + settings_path: PathBuf, + /// The path to load the desired verification key file + #[arg(long)] + vk_path: PathBuf, + /// The path to output the Solidity code + #[arg(long, default_value = "evm_da_deploy.sol")] + sol_code_path: PathBuf, + /// The path to output the Solidity verifier ABI + #[arg(long, default_value = "verifier_da_abi.json")] + abi_path: PathBuf, + /// The path to the .json data file, which should + /// contain the necessary calldata and accoount addresses + /// needed need to read from all the on-chain + /// view functions that return the data that the network + /// ingests as inputs. + #[arg(short = 'D', long)] + data: PathBuf, + // todo, optionally allow supplying proving key + }, + + #[cfg(not(target_arch = "wasm32"))] + /// Creates an EVM verifier for an aggregate proof + #[command(name = "create-evm-verifier-aggr", arg_required_else_help = true)] + CreateEVMVerifierAggr { + /// The path to load the desired srs file from + #[arg(long)] + srs_path: PathBuf, + /// The path to output to load the desired verification key file + #[arg(long)] + vk_path: PathBuf, + /// The path to the Solidity code + #[arg(long, default_value = "evm_deploy_aggr.sol")] + sol_code_path: PathBuf, + /// The path to output the Solidity verifier ABI + #[arg(long, default_value = "verifier_aggr_abi.json")] + abi_path: PathBuf, + // aggregated circuit settings paths, used to calculate the number of instances in the aggregate proof + #[arg(long)] + aggregation_settings: Vec, + }, + /// Verifies a proof, returning accept or reject + #[command(arg_required_else_help = true)] + Verify { + /// The path to load circuit params from + #[arg(short = 'S', long)] + settings_path: PathBuf, + /// The path to the proof file + #[arg(long)] + proof_path: PathBuf, + /// The path to output the desired verification key file (optional) + #[arg(long)] + vk_path: 
PathBuf, + /// The kzg srs path + #[arg(long)] + srs_path: PathBuf, + }, + /// Verifies an aggregate proof, returning accept or reject + #[command(arg_required_else_help = true)] + VerifyAggr { + /// The path to the proof file + #[arg(long)] + proof_path: PathBuf, + /// The path to output the desired verification key file (optional) + #[arg(long)] + vk_path: PathBuf, + /// The srs path + #[arg(long)] + srs_path: PathBuf, + /// logrows used for aggregation circuit + #[arg(long)] + logrows: u32, + }, + #[cfg(not(target_arch = "wasm32"))] + DeployEvmVerifier { + /// The path to the Solidity code + #[arg(long)] + sol_code_path: PathBuf, + /// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state + #[arg(short = 'U', long)] + rpc_url: Option, + #[arg(long, default_value = "contract.address")] + /// The path to output the contract address + addr_path: PathBuf, + /// The optimizer runs to set on the verifier. (Lower values optimize for deployment, while higher values optimize for execution) + #[arg(long, default_value = "1")] + optimizer_runs: usize, + /// Private secp256K1 key in hex format, 64 chars, no 0x prefix, of the account signing transactions. 
If None the private key will be generated by Anvil + #[arg(short = 'P', long)] + private_key: Option, + }, + #[cfg(not(target_arch = "wasm32"))] + #[command(name = "deploy-evm-da", arg_required_else_help = true)] + DeployEvmDataAttestation { + /// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof) + #[arg(short = 'D', long)] + data: PathBuf, + /// The path to load circuit params from + #[arg(long)] + settings_path: PathBuf, + /// The path to the Solidity code + #[arg(long)] + sol_code_path: PathBuf, + /// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state + #[arg(short = 'U', long)] + rpc_url: Option, + #[arg(long, default_value = "contract_da.address")] + /// The path to output the contract address + addr_path: PathBuf, + /// The optimizer runs to set on the verifier. (Lower values optimize for deployment, while higher values optimize for execution) + #[arg(long, default_value = "1")] + optimizer_runs: usize, + /// Private secp256K1 key in hex format, 64 chars, no 0x prefix, of the account signing transactions. If None the private key will be generated by Anvil + #[arg(short = 'P', long)] + private_key: Option, + }, + #[cfg(not(target_arch = "wasm32"))] + /// Verifies a proof using a local EVM executor, returning accept or reject + #[command(name = "verify-evm", arg_required_else_help = true)] + VerifyEVM { + /// The path to the proof file + #[arg(long)] + proof_path: PathBuf, + /// The path to verfier contract's address + #[arg(long)] + addr_verifier: H160, + /// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state + #[arg(short = 'U', long)] + rpc_url: Option, + /// does the verifier use data attestation ? 
+ #[arg(long)] + addr_da: Option, + }, + + /// Print the proof in hexadecimal + #[command(name = "print-proof-hex", arg_required_else_help = true)] + PrintProofHex { + /// The path to the proof file + #[arg(long)] + proof_path: PathBuf, + }, + + /// Gets credentials from the hub + #[command(name = "get-hub-credentials", arg_required_else_help = true)] + #[cfg(not(target_arch = "wasm32"))] + GetHubCredentials { + /// The user's api key + #[arg(short = 'K', long)] + api_key: Option, + /// The path to the model file + #[arg(short = 'N', long)] + username: String, + /// The path to the input json file + #[arg(short = 'U', long)] + url: Option, + }, + + /// Create artifacts and deploys them on the hub + #[command(name = "create-hub-artifact", arg_required_else_help = true)] + #[cfg(not(target_arch = "wasm32"))] + CreateHubArtifact { + /// The user's api key + #[arg(short = 'K', long)] + api_key: Option, + /// The path to the model file + #[arg(short = 'M', long)] + uncompiled_circuit: PathBuf, + /// The path to the input json file + #[arg(short = 'D', long)] + data: PathBuf, + /// the hub's url + #[arg(short = 'O', long)] + organization_id: String, + ///artifact name + #[arg(short = 'A', long)] + artifact_name: String, + /// the hub's url + #[arg(short = 'U', long)] + url: Option, + /// proving arguments + #[clap(flatten)] + args: RunArgs, + /// calibration target + #[arg(long, default_value = "resources")] + target: CalibrationTarget, + }, + + #[command(name = "get-hub-artifact", arg_required_else_help = true)] + #[cfg(not(target_arch = "wasm32"))] + GetHubArtifact { + /// The user's api key + #[arg(short = 'K', long)] + api_key: Option, + /// The artifact id + #[arg(short = 'A', long)] + artifact_id: String, + /// The url to send requests to + #[arg(short = 'U', long)] + url: Option, + }, + + /// Prove data on the hub + #[command(name = "prove-hub", arg_required_else_help = true)] + #[cfg(not(target_arch = "wasm32"))] + ProveHub { + /// The user's api key + 
#[arg(short = 'K', long)] + api_key: Option, + /// The path to the model file + #[arg(short = 'A', long)] + artifact_id: String, + /// The path to the input json file + #[arg(short = 'D', long)] + data: PathBuf, + /// The url to send requests to + #[arg(short = 'U', long)] + url: Option, + }, + + /// Create artifacts and deploys them on the hub + #[command(name = "get-hub-proof", arg_required_else_help = true)] + #[cfg(not(target_arch = "wasm32"))] + GetHubProof { + /// The user's api key + #[arg(short = 'K', long)] + api_key: Option, + /// The proof id + #[arg(short = 'P', long)] + proof_id: String, + /// The url to send requests to + #[arg(short = 'U', long)] + url: Option, + }, +} diff --git a/mnist_ezkl/src/eth.rs b/mnist_ezkl/src/eth.rs new file mode 100644 index 0000000..258ba49 --- /dev/null +++ b/mnist_ezkl/src/eth.rs @@ -0,0 +1,751 @@ +use crate::graph::input::{CallsToAccount, FileSourceInner, GraphData}; +use crate::graph::modules::{ELGAMAL_INSTANCES, POSEIDON_INSTANCES}; +use crate::graph::DataSource; +#[cfg(not(target_arch = "wasm32"))] +use crate::graph::GraphSettings; +use crate::pfsys::evm::EvmVerificationError; +use crate::pfsys::Snark; +use ethers::abi::Contract; +use ethers::contract::abigen; +use ethers::contract::ContractFactory; +use ethers::core::k256::ecdsa::SigningKey; +use ethers::middleware::SignerMiddleware; +use ethers::prelude::ContractInstance; +#[cfg(target_arch = "wasm32")] +use ethers::prelude::Wallet; +use ethers::providers::Middleware; +use ethers::providers::{Http, Provider}; +use ethers::signers::Signer; +use ethers::solc::{CompilerInput, Solc}; +use ethers::types::transaction::eip2718::TypedTransaction; +use ethers::types::TransactionRequest; +use ethers::types::H160; +use ethers::types::U256; +use ethers::types::{Bytes, I256}; +#[cfg(not(target_arch = "wasm32"))] +use ethers::{ + prelude::{LocalWallet, Wallet}, + utils::{Anvil, AnvilInstance}, +}; +use halo2_solidity_verifier::encode_calldata; +use halo2curves::bn256::{Fr, 
G1Affine}; +use halo2curves::group::ff::PrimeField; +use log::{debug, info, warn}; +use std::error::Error; +use std::path::PathBuf; +#[cfg(not(target_arch = "wasm32"))] +use std::time::Duration; +use std::{convert::TryFrom, sync::Arc}; + +/// A local ethers-rs based client +pub type EthersClient = Arc, Wallet>>; + +// Generate contract bindings OUTSIDE the functions so they are part of library +abigen!(TestReads, "./abis/TestReads.json"); +abigen!(DataAttestation, "./abis/DataAttestation.json"); +abigen!(QuantizeData, "./abis/QuantizeData.json"); + +const TESTREADS_SOL: &str = include_str!("../contracts/TestReads.sol"); +const QUANTIZE_DATA_SOL: &str = include_str!("../contracts/QuantizeData.sol"); +const ATTESTDATA_SOL: &str = include_str!("../contracts/AttestData.sol"); +const LOADINSTANCES_SOL: &str = include_str!("../contracts/LoadInstances.sol"); + +/// Return an instance of Anvil and a client for the given RPC URL. If none is provided, a local client is used. +#[cfg(not(target_arch = "wasm32"))] +pub async fn setup_eth_backend( + rpc_url: Option<&str>, + private_key: Option<&str>, +) -> Result<(AnvilInstance, EthersClient), Box> { + // Launch anvil + let anvil = Anvil::new() + .args(["--code-size-limit=41943040", "--disable-block-gas-limit"]) + .spawn(); + + let endpoint: String; + if let Some(rpc_url) = rpc_url { + endpoint = rpc_url.to_string(); + } else { + endpoint = anvil.endpoint(); + }; + + // Connect to the network + let provider = Provider::::try_from(endpoint)?.interval(Duration::from_millis(10u64)); + + let chain_id = provider.get_chainid().await?.as_u64(); + info!("using chain {}", chain_id); + + // Instantiate the wallet + let wallet: LocalWallet; + if let Some(private_key) = private_key { + debug!("using private key {}", private_key); + // Sanity checks for private_key + let private_key_format_error = + "Private key must be in hex format, 64 chars, without 0x prefix"; + if private_key.len() != 64 { + return Err(private_key_format_error.into()); 
+ } + let private_key_buffer = hex::decode(private_key)?; + let signing_key = SigningKey::from_slice(&private_key_buffer)?; + wallet = LocalWallet::from(signing_key); + } else { + wallet = anvil.keys()[0].clone().into(); + } + + // Instantiate the client with the signer + let client = Arc::new(SignerMiddleware::new( + provider, + wallet.with_chain_id(chain_id), + )); + + Ok((anvil, client)) +} + +/// +pub async fn deploy_verifier_via_solidity( + sol_code_path: PathBuf, + rpc_url: Option<&str>, + runs: usize, + private_key: Option<&str>, +) -> Result> { + // anvil instance must be alive at least until the factory completes the deploy + let (anvil, client) = setup_eth_backend(rpc_url, private_key).await?; + + let (abi, bytecode, runtime_bytecode) = + get_contract_artifacts(sol_code_path, "Halo2Verifier", runs)?; + + let factory = get_sol_contract_factory(abi, bytecode, runtime_bytecode, client.clone())?; + let contract = factory.deploy(())?.send().await?; + let addr = contract.address(); + + drop(anvil); + Ok(addr) +} + +/// +pub async fn deploy_da_verifier_via_solidity( + settings_path: PathBuf, + input: PathBuf, + sol_code_path: PathBuf, + rpc_url: Option<&str>, + runs: usize, + private_key: Option<&str>, +) -> Result> { + let (anvil, client) = setup_eth_backend(rpc_url, private_key).await?; + + let input = GraphData::from_path(input)?; + + let settings = GraphSettings::load(&settings_path)?; + + let mut scales: Vec = vec![]; + // The data that will be stored in the test contracts that will eventually be read from. 
+ let mut calls_to_accounts = vec![]; + + let mut instance_shapes = vec![]; + let mut model_instance_offset = 0; + + if settings.run_args.input_visibility.is_hashed() { + instance_shapes.push(POSEIDON_INSTANCES) + } else if settings.run_args.input_visibility.is_encrypted() { + instance_shapes.push(ELGAMAL_INSTANCES) + } else if settings.run_args.input_visibility.is_public() { + for idx in 0..settings.model_input_scales.len() { + let shape = &settings.model_instance_shapes[idx]; + instance_shapes.push(shape.iter().product::()); + model_instance_offset += 1; + } + } + + if settings.run_args.param_visibility.is_hashed() + || settings.run_args.param_visibility.is_encrypted() + { + return Err(Box::new(EvmVerificationError::InvalidVisibility)); + } + + if settings.run_args.output_visibility.is_hashed() { + instance_shapes.push(POSEIDON_INSTANCES) + } else if settings.run_args.output_visibility.is_encrypted() { + instance_shapes.push(ELGAMAL_INSTANCES) + } else if settings.run_args.output_visibility.is_public() { + for idx in model_instance_offset..model_instance_offset + settings.model_output_scales.len() + { + let shape = &settings.model_instance_shapes[idx]; + instance_shapes.push(shape.iter().product::()); + } + } + + println!("instance_shapes: {:#?}", instance_shapes); + + let mut instance_idx = 0; + let mut contract_instance_offset = 0; + + if let DataSource::OnChain(source) = input.input_data { + if settings.run_args.input_visibility.is_hashed_public() + | settings.run_args.input_visibility.is_encrypted() + { + // set scales 1.0 + scales.extend(vec![0; instance_shapes[instance_idx]]); + instance_idx += 1; + } else { + let input_scales = settings.model_input_scales; + // give each input a scale + for scale in input_scales { + scales.extend(vec![scale as u32; instance_shapes[instance_idx]]); + instance_idx += 1; + } + } + for call in source.calls { + calls_to_accounts.push(call); + } + } else if let DataSource::File(source) = input.input_data { + if 
settings.run_args.input_visibility.is_public() { + instance_idx += source.len(); + for s in source { + contract_instance_offset += s.len(); + } + } + } + + if let Some(DataSource::OnChain(source)) = input.output_data { + if settings.run_args.output_visibility.is_hashed_public() + | settings.run_args.output_visibility.is_encrypted() + { + // set scales 1.0 + scales.extend(vec![0; instance_shapes[instance_idx]]); + } else { + let input_scales = settings.model_output_scales; + // give each output a scale + for scale in input_scales { + scales.extend(vec![scale as u32; instance_shapes[instance_idx]]); + instance_idx += 1; + } + } + for call in source.calls { + calls_to_accounts.push(call); + } + } + + let (contract_addresses, call_data, decimals) = if !calls_to_accounts.is_empty() { + parse_calls_to_accounts(calls_to_accounts)? + } else { + return Err("Data source for either input_data or output_data must be OnChain".into()); + }; + + let (abi, bytecode, runtime_bytecode) = + get_contract_artifacts(sol_code_path, "DataAttestation", runs)?; + let factory = get_sol_contract_factory(abi, bytecode, runtime_bytecode, client.clone())?; + + info!("call_data: {:#?}", call_data); + info!("contract_addresses: {:#?}", contract_addresses); + info!("decimals: {:#?}", decimals); + + let contract = factory + .deploy(( + contract_addresses, + call_data, + decimals, + scales, + contract_instance_offset as u32, + client.address(), + ))? 
+ .send() + .await?; + + drop(anvil); + Ok(contract.address()) +} + +type ParsedCallsToAccount = (Vec, Vec>, Vec>); + +fn parse_calls_to_accounts( + calls_to_accounts: Vec, +) -> Result> { + let mut contract_addresses = vec![]; + let mut call_data = vec![]; + let mut decimals: Vec> = vec![]; + for (i, val) in calls_to_accounts.iter().enumerate() { + let contract_address_bytes = hex::decode(val.address.clone())?; + let contract_address = H160::from_slice(&contract_address_bytes); + contract_addresses.push(contract_address); + call_data.push(vec![]); + decimals.push(vec![]); + for (call, decimal) in &val.call_data { + let call_data_bytes = hex::decode(call)?; + call_data[i].push(ethers::types::Bytes::from(call_data_bytes)); + decimals[i].push(ethers::types::U256::from_dec_str(&decimal.to_string())?); + } + } + Ok((contract_addresses, call_data, decimals)) +} + +pub async fn update_account_calls( + addr: H160, + input: PathBuf, + rpc_url: Option<&str>, +) -> Result<(), Box> { + let input = GraphData::from_path(input)?; + + // The data that will be stored in the test contracts that will eventually be read from. + let mut calls_to_accounts = vec![]; + + if let DataSource::OnChain(source) = input.input_data { + for call in source.calls { + calls_to_accounts.push(call); + } + } + + if let Some(DataSource::OnChain(source)) = input.output_data { + for call in source.calls { + calls_to_accounts.push(call); + } + } + + let (contract_addresses, call_data, decimals) = if !calls_to_accounts.is_empty() { + parse_calls_to_accounts(calls_to_accounts)? 
+ } else { + return Err("Data source for either input_data or output_data must be OnChain".into()); + }; + + let (anvil, client) = setup_eth_backend(rpc_url, None).await?; + + let contract = DataAttestation::new(addr, client.clone()); + + contract + .update_account_calls( + contract_addresses.clone(), + call_data.clone(), + decimals.clone(), + ) + .send() + .await?; + + // Instantiate a different wallet + let wallet: LocalWallet = anvil.keys()[1].clone().into(); + + let client = Arc::new(client.with_signer(wallet.with_chain_id(anvil.chain_id()))); + + // update contract signer with non admin account + let contract = DataAttestation::new(addr, client.clone()); + + // call to update_account_calls should fail + + if (contract + .update_account_calls(contract_addresses, call_data, decimals) + .send() + .await) + .is_err() + { + info!("update_account_calls failed as expected"); + } else { + return Err("update_account_calls should have failed".into()); + } + + Ok(()) +} + +/// Verify a proof using a Solidity verifier contract +#[cfg(not(target_arch = "wasm32"))] +pub async fn verify_proof_via_solidity( + proof: Snark, + addr: ethers::types::Address, + rpc_url: Option<&str>, +) -> Result> { + let flattened_instances = proof.instances.into_iter().flatten(); + + let encoded = encode_calldata(None, &proof.proof, &flattened_instances.collect::>()); + + info!("encoded: {:#?}", hex::encode(&encoded)); + let (anvil, client) = setup_eth_backend(rpc_url, None).await?; + let tx: TypedTransaction = TransactionRequest::default() + .to(addr) + .from(client.address()) + .data(encoded) + .into(); + debug!("transaction {:#?}", tx); + + let result = client.call(&tx, None).await; + + if result.is_err() { + return Err(Box::new(EvmVerificationError::SolidityExecution)); + } + let result = result?; + info!("result: {:#?}", result.to_vec()); + // decode return bytes value into uint8 + let result = result.to_vec().last().ok_or("no contract ouput")? 
== &1u8; + if !result { + return Err(Box::new(EvmVerificationError::InvalidProof)); + } + + let gas = client.estimate_gas(&tx, None).await?; + + info!("estimated verify gas cost: {:#?}", gas); + + // if gas is greater than 30 million warn the user that the gas cost is above ethereum's 30 million block gas limit + if gas > 30_000_000.into() { + warn!( + "Gas cost of verify transaction is greater than 30 million block gas limit. It will fail on mainnet." + ); + } else if gas > 15_000_000.into() { + warn!( + "Gas cost of verify transaction is greater than 15 million, the target block size for ethereum" + ); + } + + drop(anvil); + Ok(true) +} + +fn count_decimal_places(num: f32) -> usize { + // Convert the number to a string + let s = num.to_string(); + + // Find the decimal point + match s.find('.') { + Some(index) => { + // Count the number of characters after the decimal point + s[index + 1..].len() + } + None => 0, + } +} + +/// +pub async fn setup_test_contract( + client: Arc, + data: &[Vec], +) -> Result<(ContractInstance, M>, Vec), Box> { + // save the abi to a tmp file + let mut sol_path = std::env::temp_dir(); + sol_path.push("testreads.sol"); + std::fs::write(&sol_path, TESTREADS_SOL)?; + + // Compile the contract + let (abi, bytecode, runtime_bytecode) = get_contract_artifacts(sol_path, "TestReads", 0)?; + + let factory = get_sol_contract_factory(abi, bytecode, runtime_bytecode, client.clone())?; + + let mut decimals = vec![]; + let mut scaled_by_decimals_data = vec![]; + for input in &data[0] { + if input.is_float() { + let input = input.to_float() as f32; + let decimal_places = count_decimal_places(input) as u8; + let scaled_by_decimals = input * f32::powf(10., decimal_places.into()); + scaled_by_decimals_data.push(I256::from(scaled_by_decimals as i128)); + decimals.push(decimal_places); + } else if input.is_field() { + let input = input.to_field(0); + let hex_str_fr = format!("{:?}", input); + 
scaled_by_decimals_data.push(I256::from_raw(U256::from_str_radix(&hex_str_fr, 16)?)); + decimals.push(0); + } + } + + let contract = factory.deploy(scaled_by_decimals_data)?.send().await?; + Ok((contract, decimals)) +} + +/// Verify a proof using a Solidity DataAttestation contract. +/// Used for testing purposes. +#[cfg(not(target_arch = "wasm32"))] +pub async fn verify_proof_with_data_attestation( + proof: Snark, + addr_verifier: ethers::types::Address, + addr_da: ethers::types::Address, + rpc_url: Option<&str>, +) -> Result> { + use ethers::abi::{Function, Param, ParamType, StateMutability, Token}; + + let mut public_inputs: Vec = vec![]; + let flattened_instances = proof.instances.into_iter().flatten(); + + for val in flattened_instances.clone() { + let bytes = val.to_repr(); + let u = U256::from_little_endian(bytes.as_slice()); + public_inputs.push(u); + } + + let encoded_verifier = + encode_calldata(None, &proof.proof, &flattened_instances.collect::>()); + + info!("encoded: {:#?}", hex::encode(&encoded_verifier)); + + info!("public_inputs: {:#?}", public_inputs); + info!( + "proof: {:#?}", + ethers::types::Bytes::from(proof.proof.to_vec()) + ); + + #[allow(deprecated)] + let func = Function { + name: "verifyWithDataAttestation".to_owned(), + inputs: vec![ + Param { + name: "verifier".to_owned(), + kind: ParamType::Address, + internal_type: None, + }, + Param { + name: "encoded".to_owned(), + kind: ParamType::Bytes, + internal_type: None, + }, + ], + outputs: vec![Param { + name: "success".to_owned(), + kind: ParamType::Bool, + internal_type: None, + }], + constant: None, + state_mutability: StateMutability::View, + }; + + let encoded = func.encode_input(&[ + Token::Address(addr_verifier), + Token::Bytes(encoded_verifier), + ])?; + + info!("encoded: {:#?}", hex::encode(&encoded)); + let (anvil, client) = setup_eth_backend(rpc_url, None).await?; + let tx: TypedTransaction = TransactionRequest::default() + .to(addr_da) + .from(client.address()) + .data(encoded) 
+ .into(); + debug!("transaction {:#?}", tx); + info!( + "estimated verify gas cost: {:#?}", + client.estimate_gas(&tx, None).await? + ); + + let result = client.call(&tx, None).await; + if result.is_err() { + return Err(Box::new(EvmVerificationError::SolidityExecution)); + } + let result = result?; + info!("result: {:#?}", result); + // decode return bytes value into uint8 + let result = result.to_vec().last().ok_or("no contract ouput")? == &1u8; + if !result { + return Err(Box::new(EvmVerificationError::InvalidProof)); + } + drop(anvil); + Ok(true) +} + +/// get_provider returns a JSON RPC HTTP Provider +pub fn get_provider(rpc_url: &str) -> Result, Box> { + let provider = Provider::::try_from(rpc_url)?; + debug!("{:#?}", provider); + Ok(provider) +} + +/// Tests on-chain data storage by deploying a contract that stores the network input and or output +/// data in its storage. It does this by converting the floating point values to integers and storing the +/// the number of decimals of the floating point value on chain. +pub async fn test_on_chain_data( + client: Arc, + data: &[Vec], +) -> Result, Box> { + let (contract, decimals) = setup_test_contract(client.clone(), data).await?; + + let contract = TestReads::new(contract.address(), client.clone()); + + // Get the encoded call data for each input + let mut calldata = vec![]; + for (i, _) in data.iter().flatten().enumerate() { + let function = contract.method::<_, I256>("arr", i as u32)?; + let call = function.calldata().ok_or("could not get calldata")?; + // Push (call, decimals) to the calldata vector. 
+ calldata.push((hex::encode(call), decimals[i])); + } + // Instantiate a new CallsToAccount struct + let calls_to_account = CallsToAccount { + call_data: calldata, + address: hex::encode(contract.address().as_bytes()), + }; + info!("calls_to_account: {:#?}", calls_to_account); + Ok(vec![calls_to_account]) +} + +/// Reads on-chain inputs, returning the raw encoded data returned from making all the calls in on_chain_input_data +#[cfg(not(target_arch = "wasm32"))] +pub async fn read_on_chain_inputs( + client: Arc, + address: H160, + data: &Vec, +) -> Result<(Vec, Vec), Box> { + // Iterate over all on-chain inputs + let mut fetched_inputs = vec![]; + let mut decimals = vec![]; + for on_chain_data in data { + // Construct the address + let contract_address_bytes = hex::decode(on_chain_data.address.clone())?; + let contract_address = H160::from_slice(&contract_address_bytes); + for (call_data, decimal) in &on_chain_data.call_data { + let call_data_bytes = hex::decode(call_data.clone())?; + let tx: TypedTransaction = TransactionRequest::default() + .to(contract_address) + .from(address) + .data(call_data_bytes) + .into(); + debug!("transaction {:#?}", tx); + + let result = client.call(&tx, None).await?; + debug!("return data {:#?}", result); + fetched_inputs.push(result); + decimals.push(*decimal); + } + } + Ok((fetched_inputs, decimals)) +} + +/// +#[cfg(not(target_arch = "wasm32"))] +pub async fn evm_quantize( + client: Arc, + scales: Vec, + data: &(Vec, Vec), +) -> Result, Box> { + // save the sol to a tmp file + let mut sol_path = std::env::temp_dir(); + sol_path.push("quantizedata.sol"); + std::fs::write(&sol_path, QUANTIZE_DATA_SOL)?; + + let (abi, bytecode, runtime_bytecode) = get_contract_artifacts(sol_path, "QuantizeData", 0)?; + let factory = get_sol_contract_factory(abi, bytecode, runtime_bytecode, client.clone())?; + + let contract = factory.deploy(())?.send().await?; + + let contract = QuantizeData::new(contract.address(), client.clone()); + + let 
fetched_inputs = data.0.clone(); + let decimals = data.1.clone(); + + let fetched_inputs = fetched_inputs + .iter() + .map(|x| Result::<_, std::convert::Infallible>::Ok(ethers::types::Bytes::from(x.to_vec()))) + .collect::, _>>()?; + + let decimals = decimals + .iter() + .map(|x| U256::from_dec_str(&x.to_string())) + .collect::, _>>()?; + + let scales = scales + .iter() + .map(|x| U256::from_dec_str(&x.to_string())) + .collect::, _>>()?; + + info!("scales: {:#?}", scales); + info!("decimals: {:#?}", decimals); + info!("fetched_inputs: {:#?}", fetched_inputs); + + let results = contract + .quantize_data(fetched_inputs, decimals, scales) + .call() + .await?; + + let felts = contract.to_field_element(results.clone()).call().await?; + info!("evm quantization contract results: {:#?}", felts,); + + let results = felts + .iter() + .map(|x| PrimeField::from_str_vartime(&x.to_string()).unwrap()) + .collect::>(); + info!("evm quantization results: {:#?}", results,); + Ok(results.to_vec()) +} + +/// Generates the contract factory for a solidity verifier, optionally compiling the code with optimizer runs set on the Solc compiler. +fn get_sol_contract_factory( + abi: Contract, + bytecode: Bytes, + runtime_bytecode: Bytes, + client: Arc, +) -> Result, Box> { + const MAX_RUNTIME_BYTECODE_SIZE: usize = 24577; + let size = runtime_bytecode.len(); + debug!("runtime bytecode size: {:#?}", size); + if size > MAX_RUNTIME_BYTECODE_SIZE { + // `_runtime_bytecode` exceeds the limit + warn!( + "Solidity runtime bytecode size is: {:#?}, + which exceeds 24577 bytes spurious dragon limit. 
+ Contract will fail to deploy on any chain with + EIP 140 enabled", + size + ); + } + Ok(ContractFactory::new(abi, bytecode, client)) +} + +/// Compiles a solidity verifier contract and returns the abi, bytecode, and runtime bytecode +#[cfg(not(target_arch = "wasm32"))] +pub fn get_contract_artifacts( + sol_code_path: PathBuf, + contract_name: &str, + runs: usize, +) -> Result<(Contract, Bytes, Bytes), Box> { + if !sol_code_path.exists() { + return Err("sol_code_path does not exist".into()); + } + // Create the compiler input, enabling the optimizer and setting the optimzer runs. + let input: CompilerInput = if runs > 0 { + let mut i = CompilerInput::new(sol_code_path)?[0] + .clone() + .optimizer(runs); + i.settings.optimizer.enable(); + i + } else { + CompilerInput::new(sol_code_path)?[0].clone() + }; + let compiled = Solc::default().compile(&input)?; + + let (abi, bytecode, runtime_bytecode) = match compiled.find(contract_name) { + Some(c) => c.into_parts_or_default(), + None => { + return Err("could not find contract".into()); + } + }; + Ok((abi, bytecode, runtime_bytecode)) +} + +/// Sets the constants stored in the da verifier +pub fn fix_da_sol( + input_data: Option>, + output_data: Option>, +) -> Result> { + let mut accounts_len = 0; + let mut contract = ATTESTDATA_SOL.to_string(); + let load_instances = LOADINSTANCES_SOL.to_string(); + // replace the import statement with the load_instances contract, not including the + // `SPDX-License-Identifier: MIT pragma solidity ^0.8.20;` at the top of the file + contract = contract.replace( + "import './LoadInstances.sol';", + &load_instances[load_instances + .find("contract") + .ok_or("could not get load-instances contract")?..], + ); + + // fill in the quantization params and total calls + // as constants to the contract to save on gas + if let Some(input_data) = input_data { + let input_calls: usize = input_data.iter().map(|v| v.call_data.len()).sum(); + accounts_len = input_data.len(); + contract = 
contract.replace( + "uint256 constant INPUT_CALLS = 0;", + &format!("uint256 constant INPUT_CALLS = {};", input_calls), + ); + } + if let Some(output_data) = output_data { + let output_calls: usize = output_data.iter().map(|v| v.call_data.len()).sum(); + accounts_len += output_data.len(); + contract = contract.replace( + "uint256 constant OUTPUT_CALLS = 0;", + &format!("uint256 constant OUTPUT_CALLS = {};", output_calls), + ); + } + contract = contract.replace("AccountCall[]", &format!("AccountCall[{}]", accounts_len)); + + Ok(contract) +} diff --git a/mnist_ezkl/src/execute.rs b/mnist_ezkl/src/execute.rs new file mode 100644 index 0000000..63215d7 --- /dev/null +++ b/mnist_ezkl/src/execute.rs @@ -0,0 +1,2125 @@ +use crate::circuit::CheckMode; +#[cfg(not(target_arch = "wasm32"))] +use crate::commands::CalibrationTarget; +use crate::commands::Commands; +#[cfg(not(target_arch = "wasm32"))] +use crate::eth::{deploy_da_verifier_via_solidity, deploy_verifier_via_solidity}; +#[cfg(not(target_arch = "wasm32"))] +use crate::eth::{fix_da_sol, get_contract_artifacts, verify_proof_via_solidity}; +use crate::graph::input::GraphData; +use crate::graph::{GraphCircuit, GraphSettings, GraphWitness, Model}; +#[cfg(not(target_arch = "wasm32"))] +use crate::graph::{TestDataSource, TestSources}; +use crate::pfsys::evm::aggregation::AggregationCircuit; +#[cfg(not(target_arch = "wasm32"))] +use crate::pfsys::evm::{single::gen_evm_verifier, YulCode}; +use crate::pfsys::{ + create_keys, load_pk, load_vk, save_params, save_pk, swap_proof_commitments_kzg, Snark, + StrategyType, TranscriptType, +}; +use crate::pfsys::{create_proof_circuit_kzg, verify_proof_circuit_kzg}; +use crate::pfsys::{save_vk, srs::*}; +use crate::RunArgs; +#[cfg(not(target_arch = "wasm32"))] +use ethers::types::H160; +use gag::Gag; +use halo2_proofs::dev::VerifyFailure; +use halo2_proofs::poly::commitment::Params; +use halo2_proofs::poly::commitment::ParamsProver; +use 
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme; +use halo2_proofs::poly::kzg::strategy::AccumulatorStrategy; +use halo2_proofs::poly::kzg::{ + commitment::ParamsKZG, strategy::SingleStrategy as KZGSingleStrategy, +}; +#[cfg(not(target_arch = "wasm32"))] +use halo2_solidity_verifier; +use halo2curves::bn256::{Bn256, Fr, G1Affine}; +#[cfg(not(target_arch = "wasm32"))] +use halo2curves::ff::Field; +#[cfg(not(target_arch = "wasm32"))] +use indicatif::{ProgressBar, ProgressStyle}; +use instant::Instant; +#[cfg(not(target_arch = "wasm32"))] +use itertools::Itertools; +#[cfg(not(target_arch = "wasm32"))] +use log::debug; +use log::{info, trace}; +#[cfg(feature = "render")] +use plotters::prelude::*; +#[cfg(not(target_arch = "wasm32"))] +use rand::Rng; +#[cfg(not(target_arch = "wasm32"))] +use rayon::prelude::{IntoParallelIterator, ParallelIterator}; +use std::error::Error; +use std::fs::File; +#[cfg(not(target_arch = "wasm32"))] +use std::io::{Cursor, Write}; +use std::path::{Path, PathBuf}; +#[cfg(not(target_arch = "wasm32"))] +use std::process::Command; +#[cfg(not(target_arch = "wasm32"))] +use std::sync::atomic::{AtomicBool, AtomicI64, Ordering}; +#[cfg(not(target_arch = "wasm32"))] +use std::sync::OnceLock; +#[cfg(not(target_arch = "wasm32"))] +use std::time::Duration; +use thiserror::Error; +#[cfg(not(target_arch = "wasm32"))] +use tokio_util::codec::{BytesCodec, FramedRead}; + +#[cfg(not(target_arch = "wasm32"))] +static _SOLC_REQUIREMENT: OnceLock = OnceLock::new(); +#[cfg(not(target_arch = "wasm32"))] +fn check_solc_requirement() { + info!("checking solc installation.."); + _SOLC_REQUIREMENT.get_or_init(|| match Command::new("solc").arg("--version").output() { + Ok(output) => { + debug!("solc output: {:#?}", output); + debug!("solc output success: {:#?}", output.status.success()); + if !output.status.success() { + log::error!( + "`solc` check failed: {}", + String::from_utf8_lossy(&output.stderr) + ); + return false; + } + debug!("solc check passed, 
proceeding"); + true + } + Err(_) => { + log::error!("`solc` check failed: solc not found"); + false + } + }); +} + +/// A wrapper for tensor related errors. +#[derive(Debug, Error)] +pub enum ExecutionError { + /// Shape mismatch in a operation + #[error("verification failed")] + VerifyError(Vec), +} + +/// Run an ezkl command with given args +pub async fn run(command: Commands) -> Result<(), Box> { + match command { + Commands::Empty => Ok(()), + #[cfg(not(target_arch = "wasm32"))] + Commands::Fuzz { + witness, + compiled_circuit, + transcript, + num_runs, + } => fuzz(compiled_circuit, witness, transcript, num_runs), + + Commands::GenSrs { srs_path, logrows } => gen_srs_cmd(srs_path, logrows as u32), + #[cfg(not(target_arch = "wasm32"))] + Commands::GetSrs { + srs_path, + settings_path, + logrows, + check, + } => get_srs_cmd(srs_path, settings_path, logrows, check).await, + Commands::Table { model, args } => table(model, args), + #[cfg(feature = "render")] + Commands::RenderCircuit { + model, + output, + args, + } => render(model, output, args), + Commands::GenSettings { + model, + settings_path, + args, + } => gen_circuit_settings(model, settings_path, args), + #[cfg(not(target_arch = "wasm32"))] + Commands::CalibrateSettings { + model, + settings_path, + data, + target, + scales, + max_logrows, + } => calibrate(model, data, settings_path, target, scales, max_logrows), + Commands::GenWitness { + data, + compiled_circuit, + output, + vk_path, + srs_path, + } => gen_witness(compiled_circuit, data, Some(output), vk_path, srs_path) + .await + .map(|_| ()), + Commands::Mock { model, witness } => mock(model, witness), + #[cfg(not(target_arch = "wasm32"))] + Commands::CreateEVMVerifier { + vk_path, + srs_path, + settings_path, + sol_code_path, + abi_path, + } => create_evm_verifier(vk_path, srs_path, settings_path, sol_code_path, abi_path), + #[cfg(not(target_arch = "wasm32"))] + Commands::CreateEVMDataAttestation { + vk_path, + srs_path, + settings_path, + 
sol_code_path, + abi_path, + data, + } => create_evm_data_attestation( + vk_path, + srs_path, + settings_path, + sol_code_path, + abi_path, + data, + ), + #[cfg(not(target_arch = "wasm32"))] + Commands::CreateEVMVerifierAggr { + vk_path, + srs_path, + sol_code_path, + abi_path, + aggregation_settings, + } => create_evm_aggregate_verifier( + vk_path, + srs_path, + sol_code_path, + abi_path, + aggregation_settings, + ), + Commands::CompileCircuit { + model, + compiled_circuit, + settings_path, + } => compile_circuit(model, compiled_circuit, settings_path), + Commands::Setup { + compiled_circuit, + srs_path, + vk_path, + pk_path, + witness, + } => setup(compiled_circuit, srs_path, vk_path, pk_path, witness), + #[cfg(not(target_arch = "wasm32"))] + Commands::SetupTestEVMData { + data, + compiled_circuit, + test_data, + rpc_url, + input_source, + output_source, + } => { + setup_test_evm_witness( + data, + compiled_circuit, + test_data, + rpc_url, + input_source, + output_source, + ) + .await + } + #[cfg(not(target_arch = "wasm32"))] + Commands::TestUpdateAccountCalls { + addr, + data, + rpc_url, + } => test_update_account_calls(addr, data, rpc_url).await, + #[cfg(not(target_arch = "wasm32"))] + Commands::SwapProofCommitments { + proof_path, + witness_path, + } => swap_proof_commitments(proof_path, witness_path), + #[cfg(not(target_arch = "wasm32"))] + Commands::Prove { + witness, + compiled_circuit, + pk_path, + proof_path, + srs_path, + proof_type, + check_mode, + } => prove( + witness, + compiled_circuit, + pk_path, + Some(proof_path), + srs_path, + proof_type, + check_mode, + ) + .map(|_| ()), + Commands::MockAggregate { + aggregation_snarks, + logrows, + split_proofs, + } => mock_aggregate(aggregation_snarks, logrows, split_proofs), + Commands::SetupAggregate { + sample_snarks, + vk_path, + pk_path, + srs_path, + logrows, + split_proofs, + } => setup_aggregate( + sample_snarks, + vk_path, + pk_path, + srs_path, + logrows, + split_proofs, + ), + Commands::Aggregate { 
+ proof_path, + aggregation_snarks, + pk_path, + srs_path, + transcript, + logrows, + check_mode, + split_proofs, + } => aggregate( + proof_path, + aggregation_snarks, + pk_path, + srs_path, + transcript, + logrows, + check_mode, + split_proofs, + ), + Commands::Verify { + proof_path, + settings_path, + vk_path, + srs_path, + } => verify(proof_path, settings_path, vk_path, srs_path), + Commands::VerifyAggr { + proof_path, + vk_path, + srs_path, + logrows, + } => verify_aggr(proof_path, vk_path, srs_path, logrows), + #[cfg(not(target_arch = "wasm32"))] + Commands::DeployEvmVerifier { + sol_code_path, + rpc_url, + addr_path, + optimizer_runs, + private_key, + } => { + deploy_evm( + sol_code_path, + rpc_url, + addr_path, + optimizer_runs, + private_key, + ) + .await + } + #[cfg(not(target_arch = "wasm32"))] + Commands::DeployEvmDataAttestation { + data, + settings_path, + sol_code_path, + rpc_url, + addr_path, + optimizer_runs, + private_key, + } => { + deploy_da_evm( + data, + settings_path, + sol_code_path, + rpc_url, + addr_path, + optimizer_runs, + private_key, + ) + .await + } + #[cfg(not(target_arch = "wasm32"))] + Commands::VerifyEVM { + proof_path, + addr_verifier, + rpc_url, + addr_da, + } => verify_evm(proof_path, addr_verifier, rpc_url, addr_da).await, + Commands::PrintProofHex { proof_path } => print_proof_hex(proof_path), + #[cfg(not(target_arch = "wasm32"))] + Commands::GetHubCredentials { + api_key, + username, + url, + } => get_hub_credentials(api_key.as_deref(), url.as_deref(), &username) + .await + .map(|_| ()), + + #[cfg(not(target_arch = "wasm32"))] + Commands::CreateHubArtifact { + api_key, + uncompiled_circuit, + data, + organization_id, + artifact_name, + url, + args, + target, + } => deploy_model( + api_key.as_deref(), + url.as_deref(), + &uncompiled_circuit, + &data, + &artifact_name, + &organization_id, + &args, + &target, + ) + .await + .map(|_| ()), + #[cfg(not(target_arch = "wasm32"))] + Commands::GetHubArtifact { + api_key, + artifact_id, 
+ url, + } => get_deployed_model(api_key.as_deref(), url.as_deref(), &artifact_id) + .await + .map(|_| ()), + #[cfg(not(target_arch = "wasm32"))] + Commands::GetHubProof { + api_key, + proof_id, + url, + } => get_hub_proof(api_key.as_deref(), url.as_deref(), &proof_id) + .await + .map(|_| ()), + #[cfg(not(target_arch = "wasm32"))] + Commands::ProveHub { + api_key, + artifact_id, + data, + url, + } => prove_hub(api_key.as_deref(), url.as_deref(), &artifact_id, &data) + .await + .map(|_| ()), + } +} + +pub(crate) fn gen_srs_cmd(srs_path: PathBuf, logrows: u32) -> Result<(), Box> { + let params = gen_srs::>(logrows); + save_params::>(&srs_path, ¶ms)?; + Ok(()) +} + +#[cfg(not(target_arch = "wasm32"))] +async fn fetch_srs(uri: &str) -> Result, Box> { + let pb = { + let pb = init_spinner(); + pb.set_message("Downloading SRS (this may take a while) ..."); + pb + }; + let client = reqwest::Client::new(); + // wasm doesn't require it to be mutable + #[allow(unused_mut)] + let mut resp = client.get(uri).body(vec![]).send().await?; + let mut buf = vec![]; + while let Some(chunk) = resp.chunk().await? { + buf.extend(chunk.to_vec()); + } + + pb.finish_with_message("SRS downloaded."); + Ok(std::mem::take(&mut buf)) +} + +#[cfg(not(target_arch = "wasm32"))] +pub(crate) async fn get_srs_cmd( + srs_path: PathBuf, + settings_path: Option, + logrows: Option, + check_mode: CheckMode, +) -> Result<(), Box> { + let k = if let Some(settings_p) = settings_path { + if settings_p.exists() { + let settings = GraphSettings::load(&settings_p)?; + settings.run_args.logrows + } else { + let err_string = format!( + "You will need to provide a valid settings file to use the settings option. You should run gen-settings to generate a settings file (and calibrate-settings to pick optimal logrows)." + ); + return Err(err_string.into()); + } + } else if let Some(k) = logrows { + k + } else { + let err_string = format!( + "You will need to provide a settings file or set the logrows. 
You should run gen-settings to generate a settings file (and calibrate-settings to pick optimal logrows)." + ); + return Err(err_string.into()); + }; + + let srs_uri = format!("{}{}", PUBLIC_SRS_URL, k); + let mut reader = Cursor::new(fetch_srs(&srs_uri).await?); + // check the SRS + if matches!(check_mode, CheckMode::SAFE) { + #[cfg(not(target_arch = "wasm32"))] + let pb = init_spinner(); + #[cfg(not(target_arch = "wasm32"))] + pb.set_message("Validating SRS (this may take a while) ..."); + ParamsKZG::::read(&mut reader)?; + #[cfg(not(target_arch = "wasm32"))] + pb.finish_with_message("SRS validated"); + } + + let mut file = std::fs::File::create(srs_path)?; + file.write_all(reader.get_ref())?; + + info!("SRS downloaded"); + Ok(()) +} + +pub(crate) fn table(model: PathBuf, run_args: RunArgs) -> Result<(), Box> { + let model = Model::from_run_args(&run_args, &model)?; + info!("\n {}", model.table_nodes()); + Ok(()) +} + +pub(crate) async fn gen_witness( + compiled_circuit_path: PathBuf, + data: PathBuf, + output: Option, + vk_path: Option, + srs_path: Option, +) -> Result> { + // these aren't real values so the sanity checks are mostly meaningless + + let mut circuit = GraphCircuit::load(compiled_circuit_path)?; + let data = GraphData::from_path(data)?; + let settings = circuit.settings().clone(); + + let vk = if let Some(vk) = vk_path { + Some(load_vk::, Fr, GraphCircuit>( + vk, + settings.clone(), + )?) + } else { + None + }; + + let srs = if let Some(srs) = srs_path { + Some(load_params_cmd(srs, settings.run_args.logrows)?) 
+ } else { + None + }; + + #[cfg(not(target_arch = "wasm32"))] + let mut input = circuit.load_graph_input(&data).await?; + #[cfg(target_arch = "wasm32")] + let mut input = circuit.load_graph_input(&data)?; + + let start_time = Instant::now(); + + let witness = circuit.forward(&mut input, vk.as_ref(), srs.as_ref())?; + + // print each variable tuple (symbol, value) as symbol=value + trace!( + "witness generation {:?} took {:?}", + circuit + .settings() + .run_args + .variables + .iter() + .map(|v| { format!("{}={}", v.0, v.1) }) + .collect::>(), + start_time.elapsed() + ); + + if let Some(output_path) = output { + serde_json::to_writer(&File::create(output_path)?, &witness)?; + } + Ok(witness) +} + +/// Generate a circuit settings file +pub(crate) fn gen_circuit_settings( + model_path: PathBuf, + params_output: PathBuf, + run_args: RunArgs, +) -> Result<(), Box> { + let circuit = GraphCircuit::from_run_args(&run_args, &model_path)?; + let params = circuit.settings(); + params.save(¶ms_output).map_err(Box::::from) +} + +// not for wasm targets +#[cfg(not(target_arch = "wasm32"))] +pub(crate) fn init_spinner() -> ProgressBar { + let pb = indicatif::ProgressBar::new_spinner(); + pb.set_draw_target(indicatif::ProgressDrawTarget::stdout()); + pb.enable_steady_tick(Duration::from_millis(200)); + pb.set_style( + ProgressStyle::with_template("[{elapsed_precise}] {spinner:.blue} {msg}") + .unwrap() + .tick_strings(&[ + "------ - ✨ ", + "------ - ⏳ ", + "------ - 🌎 ", + "------ - 🔎 ", + "------ - 🥹 ", + "------ - 🫠 ", + "------ - 👾 ", + ]), + ); + pb +} + +// not for wasm targets +#[cfg(not(target_arch = "wasm32"))] +pub(crate) fn init_bar(len: u64) -> ProgressBar { + let pb = ProgressBar::new(len); + pb.set_draw_target(indicatif::ProgressDrawTarget::stdout()); + pb.enable_steady_tick(Duration::from_millis(200)); + let sty = ProgressStyle::with_template( + "[{elapsed_precise}] {bar:40.cyan/blue} {pos:>7}/{len:7} {msg}", + ) + .unwrap() + .progress_chars("##-"); + 
pb.set_style(sty); + pb +} + +#[cfg(not(target_arch = "wasm32"))] +use colored_json::ToColoredJson; + +/// Calibrate the circuit parameters to a given a dataset +#[cfg(not(target_arch = "wasm32"))] +#[allow(trivial_casts)] +pub(crate) fn calibrate( + model_path: PathBuf, + data: PathBuf, + settings_path: PathBuf, + target: CalibrationTarget, + scales: Option>, + max_logrows: Option, +) -> Result<(), Box> { + let data = GraphData::from_path(data)?; + // load the pre-generated settings + let settings = GraphSettings::load(&settings_path)?; + // now retrieve the run args + // we load the model to get the input and output shapes + // check if gag already exists + + let _r = match Gag::stdout() { + Ok(r) => Some(r), + Err(_) => None, + }; + + let model = Model::from_run_args(&settings.run_args, &model_path)?; + // drop the gag + std::mem::drop(_r); + + let range = if let Some(scales) = scales { + scales + } else { + match target { + CalibrationTarget::Resources { .. } => (8..10).collect::>(), + CalibrationTarget::Accuracy => (10..14).collect::>(), + } + }; + + let chunks = data.split_into_batches(model.graph.input_shapes()?)?; + + info!("num of calibration batches: {}", chunks.len()); + + let mut found_params: Vec = vec![]; + + let scale_rebase_multiplier = [1, 2, 10]; + + // 2 x 2 grid + let range_grid = range + .iter() + .cartesian_product(range.iter()) + .map(|(a, b)| (*a, *b)) + .collect::>(); + + // remove all entries where input_scale > param_scale + let mut range_grid = range_grid + .into_iter() + .filter(|(a, b)| a <= b) + .collect::>(); + + // if all integers + let all_scale_0 = model + .graph + .get_input_types()? 
+ .iter() + .all(|t| t.is_integer()); + if all_scale_0 { + // set all a values to 0 then dedup + range_grid = range_grid + .iter() + .map(|(_, b)| (0, *b)) + .sorted() + .dedup() + .collect::>(); + } + + let range_grid = range_grid + .iter() + .cartesian_product(scale_rebase_multiplier.iter()) + .map(|(a, b)| (*a, *b)) + .collect::>(); + + let pb = init_bar(range_grid.len() as u64); + pb.set_message("calibrating..."); + + for ((input_scale, param_scale), scale_rebase_multiplier) in range_grid { + pb.set_message(format!( + "input scale: {}, param scale: {}, scale rebase multiplier: {}", + input_scale, param_scale, scale_rebase_multiplier + )); + // vec of settings copied chunks.len() times + let run_args_iterable = vec![settings.run_args.clone(); chunks.len()]; + + let _r = match Gag::stdout() { + Ok(r) => Some(r), + Err(_) => None, + }; + let _q = match Gag::stderr() { + Ok(r) => Some(r), + Err(_) => None, + }; + + let tasks = chunks + .iter() + .zip(run_args_iterable) + .map(|(chunk, run_args)| { + // we need to create a new run args for each chunk + // time it + let chunk = chunk.clone(); + let local_run_args = RunArgs { + input_scale, + param_scale, + scale_rebase_multiplier, + ..run_args.clone() + }; + + let original_settings = settings.clone(); + + let mut circuit = match GraphCircuit::from_run_args(&local_run_args, &model_path) { + Ok(c) => c, + Err(_) => { + return Err(format!("failed to create circuit from run args")) + as Result + } + }; + + let data = circuit + .load_graph_from_file_exclusively(&chunk) + .map_err(|e| format!("failed to load circuit inputs: {}", e))?; + + circuit + .calibrate(&data, max_logrows) + .map_err(|e| format!("failed to calibrate: {}", e))?; + + let settings = circuit.settings().clone(); + + let found_run_args = RunArgs { + input_scale: settings.run_args.input_scale, + param_scale: settings.run_args.param_scale, + lookup_range: settings.run_args.lookup_range, + logrows: settings.run_args.logrows, + scale_rebase_multiplier: 
settings.run_args.scale_rebase_multiplier, + ..run_args.clone() + }; + + let found_settings = GraphSettings { + run_args: found_run_args, + required_lookups: settings.required_lookups, + model_output_scales: settings.model_output_scales, + model_input_scales: settings.model_input_scales, + num_rows: settings.num_rows, + total_assignments: settings.total_assignments, + total_const_size: settings.total_const_size, + ..original_settings.clone() + }; + + Ok(found_settings) as Result + }) + .collect::>>(); + + let mut res: Vec = vec![]; + for task in tasks { + if let Ok(task) = task { + res.push(task); + } + } + + // drop the gag + std::mem::drop(_r); + std::mem::drop(_q); + + let max_lookup_range = res + .iter() + .map(|x| x.run_args.lookup_range.1) + .max() + .unwrap_or(0); + let min_lookup_range = res + .iter() + .map(|x| x.run_args.lookup_range.0) + .min() + .unwrap_or(0); + + if let Some(mut best) = res.into_iter().max_by_key(|p| { + ( + p.run_args.logrows, + p.run_args.input_scale, + p.run_args.param_scale, + ) + }) { + best.run_args.lookup_range = (min_lookup_range, max_lookup_range); + // pick the one with the largest logrows + found_params.push(best.clone()); + debug!( + "found settings: \n {}", + best.as_json()?.to_colored_json_auto()? + ); + } + + pb.inc(1); + } + + pb.finish_with_message("Calibration Done."); + + if found_params.is_empty() { + return Err("calibration failed, could not find any suitable parameters given the calibration dataset".into()); + } + + debug!("Found {} sets of parameters", found_params.len()); + + // now find the best params according to the target + let mut best_params = match target { + CalibrationTarget::Resources { .. } => { + let mut param_iterator = found_params.iter().sorted_by_key(|p| p.run_args.logrows); + + let min_logrows = param_iterator + .next() + .ok_or("no params found")? 
+ .run_args + .logrows; + + // pick the ones that have the minimum logrows but also the largest scale: + // this is the best tradeoff between resource usage and accuracy + found_params + .iter() + .filter(|p| p.run_args.logrows == min_logrows) + .max_by_key(|p| { + ( + p.run_args.input_scale, + p.run_args.param_scale, + // we want the largest rebase multiplier as it means we can use less constraints + p.run_args.scale_rebase_multiplier, + ) + }) + .ok_or("no params found")? + .clone() + } + CalibrationTarget::Accuracy => { + let param_iterator = found_params.iter().sorted_by_key(|p| { + ( + p.run_args.input_scale, + p.run_args.param_scale, + // we want the largest rebase multiplier as it means we can use less constraints + p.run_args.scale_rebase_multiplier, + ) + }); + + let last = param_iterator.last().ok_or("no params found")?; + let max_scale = ( + last.run_args.input_scale, + last.run_args.param_scale, + last.run_args.scale_rebase_multiplier, + ); + + // pick the ones that have the max scale but also the smallest logrows: + // this is the best tradeoff between resource usage and accuracy + found_params + .iter() + .filter(|p| { + ( + p.run_args.input_scale, + p.run_args.param_scale, + p.run_args.scale_rebase_multiplier, + ) == max_scale + }) + .min_by_key(|p| p.run_args.logrows) + .ok_or("no params found")? 
+ .clone() + } + }; + + if matches!(target, CalibrationTarget::Resources { col_overflow: true }) { + let lookup_log_rows = ((best_params.run_args.lookup_range.1 + - best_params.run_args.lookup_range.0) as f32) + .log2() + .ceil() as u32 + + 1; + let mut reduction = std::cmp::max( + (best_params + .model_instance_shapes + .iter() + .map(|x| x.iter().product::()) + .sum::() as f32) + .log2() + .ceil() as u32 + + 1, + lookup_log_rows, + ); + reduction = std::cmp::max(reduction, crate::graph::MIN_LOGROWS); + + info!( + "logrows > bits, shrinking logrows: {} -> {}", + best_params.run_args.logrows, reduction + ); + + best_params.run_args.logrows = reduction; + } + + best_params.save(&settings_path)?; + + debug!("Saved parameters."); + + Ok(()) +} + +pub(crate) fn mock( + compiled_circuit_path: PathBuf, + data_path: PathBuf, +) -> Result<(), Box> { + // mock should catch any issues by default so we set it to safe + let mut circuit = GraphCircuit::load(compiled_circuit_path)?; + + let data = GraphWitness::from_path(data_path)?; + + circuit.load_graph_witness(&data)?; + + let public_inputs = circuit.prepare_public_inputs(&data)?; + + info!("Mock proof"); + + let prover = halo2_proofs::dev::MockProver::run( + circuit.settings().run_args.logrows, + &circuit, + vec![public_inputs], + ) + .map_err(Box::::from)?; + prover + .verify_par() + .map_err(|e| Box::::from(ExecutionError::VerifyError(e)))?; + Ok(()) +} + +pub(crate) fn print_proof_hex(proof_path: PathBuf) -> Result<(), Box> { + let proof = Snark::load::>(&proof_path)?; + for instance in proof.instances { + println!("{:?}", instance); + } + info!("{}", hex::encode(proof.proof)); + Ok(()) +} + +#[cfg(feature = "render")] +pub(crate) fn render(model: PathBuf, output: PathBuf, args: RunArgs) -> Result<(), Box> { + let circuit = GraphCircuit::from_run_args(&args, &model)?; + info!("Rendering circuit"); + + // Create the area we want to draw on. + // We could use SVGBackend if we want to render to .svg instead. 
+ // for an overview of how to interpret these plots, see https://zcash.github.io/halo2/user/dev-tools.html + let root = BitMapBackend::new(&output, (512, 512)).into_drawing_area(); + root.fill(&TRANSPARENT)?; + let root = root.titled("Layout", ("sans-serif", 20))?; + + halo2_proofs::dev::CircuitLayout::default() + // We hide labels, else most circuits become impossible to decipher because of overlaid text + .show_labels(false) + .render(circuit.settings().run_args.logrows, &circuit, &root)?; + Ok(()) +} + +#[cfg(not(target_arch = "wasm32"))] +pub(crate) fn create_evm_verifier( + vk_path: PathBuf, + srs_path: PathBuf, + settings_path: PathBuf, + sol_code_path: PathBuf, + abi_path: PathBuf, +) -> Result<(), Box> { + check_solc_requirement(); + let circuit_settings = GraphSettings::load(&settings_path)?; + let params = load_params_cmd(srs_path, circuit_settings.run_args.logrows)?; + + let num_instance = circuit_settings.total_instances(); + let num_instance: usize = num_instance.iter().sum::(); + + let vk = load_vk::, Fr, GraphCircuit>(vk_path, circuit_settings)?; + trace!("params computed"); + + let generator = halo2_solidity_verifier::SolidityGenerator::new( + ¶ms, + &vk, + halo2_solidity_verifier::BatchOpenScheme::Bdfg21, + num_instance, + ); + let verifier_solidity = generator.render()?; + + File::create(sol_code_path.clone())?.write_all(verifier_solidity.as_bytes())?; + + // fetch abi of the contract + let (abi, _, _) = get_contract_artifacts(sol_code_path, "Halo2Verifier", 0)?; + // save abi to file + serde_json::to_writer(std::fs::File::create(abi_path)?, &abi)?; + + Ok(()) +} + +#[cfg(not(target_arch = "wasm32"))] +pub(crate) fn create_evm_data_attestation( + vk_path: PathBuf, + srs_path: PathBuf, + settings_path: PathBuf, + sol_code_path: PathBuf, + abi_path: PathBuf, + input: PathBuf, +) -> Result<(), Box> { + use crate::graph::{DataSource, VarVisibility}; + check_solc_requirement(); + + let settings = GraphSettings::load(&settings_path)?; + let params = 
load_params_cmd(srs_path, settings.run_args.logrows)?; + + let visibility = VarVisibility::from_args(&settings.run_args)?; + + let num_instance = settings.total_instances(); + let num_instance: usize = num_instance.iter().sum::(); + + let vk = load_vk::, Fr, GraphCircuit>(vk_path, settings.clone())?; + trace!("params computed"); + + let yul_code: YulCode = gen_evm_verifier(¶ms, &vk, num_instance)?; + + let mut f = File::create(sol_code_path.clone())?; + let _ = f.write(yul_code.as_bytes()); + + let data = GraphData::from_path(input)?; + + let output_data = if let Some(DataSource::OnChain(source)) = data.output_data { + if visibility.output.is_private() { + return Err("private output data on chain is not supported on chain".into()); + } + let mut on_chain_output_data = vec![]; + for call in source.calls { + on_chain_output_data.push(call); + } + Some(on_chain_output_data) + } else { + None + }; + + let input_data = if let DataSource::OnChain(source) = data.input_data { + if visibility.input.is_private() { + return Err("private input data on chain is not supported on chain".into()); + } + let mut on_chain_input_data = vec![]; + for call in source.calls { + on_chain_input_data.push(call); + } + Some(on_chain_input_data) + } else { + None + }; + + if input_data.is_some() || output_data.is_some() { + let output = fix_da_sol(input_data, output_data)?; + let mut f = File::create(sol_code_path.clone())?; + let _ = f.write(output.as_bytes()); + // fetch abi of the contract + let (abi, _, _) = get_contract_artifacts(sol_code_path, "DataAttestation", 0)?; + // save abi to file + serde_json::to_writer(std::fs::File::create(abi_path)?, &abi)?; + } else { + return Err( + "Neither input or output data source is on-chain. 
Atleast one must be on chain.".into(), + ); + } + Ok(()) +} + +#[cfg(not(target_arch = "wasm32"))] +pub(crate) async fn deploy_da_evm( + data: PathBuf, + settings_path: PathBuf, + sol_code_path: PathBuf, + rpc_url: Option, + addr_path: PathBuf, + runs: usize, + private_key: Option, +) -> Result<(), Box> { + check_solc_requirement(); + let contract_address = deploy_da_verifier_via_solidity( + settings_path, + data, + sol_code_path, + rpc_url.as_deref(), + runs, + private_key.as_deref(), + ) + .await?; + info!("Contract deployed at: {}", contract_address); + + let mut f = File::create(addr_path)?; + write!(f, "{:#?}", contract_address)?; + + Ok(()) +} + +#[cfg(not(target_arch = "wasm32"))] +pub(crate) async fn deploy_evm( + sol_code_path: PathBuf, + rpc_url: Option, + addr_path: PathBuf, + runs: usize, + private_key: Option, +) -> Result<(), Box> { + check_solc_requirement(); + let contract_address = deploy_verifier_via_solidity( + sol_code_path, + rpc_url.as_deref(), + runs, + private_key.as_deref(), + ) + .await?; + + info!("Contract deployed at: {:#?}", contract_address); + + let mut f = File::create(addr_path)?; + write!(f, "{:#?}", contract_address)?; + Ok(()) +} + +#[cfg(not(target_arch = "wasm32"))] +pub(crate) async fn verify_evm( + proof_path: PathBuf, + addr_verifier: H160, + rpc_url: Option, + addr_da: Option, +) -> Result<(), Box> { + use crate::eth::verify_proof_with_data_attestation; + check_solc_requirement(); + + let proof = Snark::load::>(&proof_path)?; + + let result = if let Some(addr_da) = addr_da { + verify_proof_with_data_attestation( + proof.clone(), + addr_verifier, + addr_da, + rpc_url.as_deref(), + ) + .await? + } else { + verify_proof_via_solidity(proof.clone(), addr_verifier, rpc_url.as_deref()).await? 
+ }; + + info!("Solidity verification result: {}", result); + + if !result { + return Err("Solidity verification failed".into()); + } + + Ok(()) +} + +#[cfg(not(target_arch = "wasm32"))] +pub(crate) fn create_evm_aggregate_verifier( + vk_path: PathBuf, + srs_path: PathBuf, + sol_code_path: PathBuf, + abi_path: PathBuf, + circuit_settings: Vec, +) -> Result<(), Box> { + check_solc_requirement(); + let params: ParamsKZG = load_srs::>(srs_path)?; + + let mut settings: Vec = vec![]; + + for path in circuit_settings.iter() { + let s = GraphSettings::load(path)?; + settings.push(s); + } + + let num_instance: usize = settings + .iter() + .map(|s| s.total_instances().iter().sum::()) + .sum(); + + let num_instance = AggregationCircuit::num_instance(num_instance); + assert_eq!(num_instance.len(), 1); + let num_instance = num_instance[0]; + + let agg_vk = load_vk::, Fr, AggregationCircuit>(vk_path, ())?; + + let mut generator = halo2_solidity_verifier::SolidityGenerator::new( + ¶ms, + &agg_vk, + halo2_solidity_verifier::BatchOpenScheme::Bdfg21, + num_instance, + ); + + let acc_encoding = halo2_solidity_verifier::AccumulatorEncoding::new( + 0, + AggregationCircuit::num_limbs(), + AggregationCircuit::num_bits(), + ); + + generator = generator.set_acc_encoding(Some(acc_encoding)); + + let verifier_solidity = generator.render()?; + + File::create(sol_code_path.clone())?.write_all(verifier_solidity.as_bytes())?; + + // fetch abi of the contract + let (abi, _, _) = get_contract_artifacts(sol_code_path, "Halo2Verifier", 0)?; + // save abi to file + serde_json::to_writer(std::fs::File::create(abi_path)?, &abi)?; + + Ok(()) +} + +pub(crate) fn compile_circuit( + model_path: PathBuf, + compiled_circuit: PathBuf, + settings_path: PathBuf, +) -> Result<(), Box> { + let settings = GraphSettings::load(&settings_path)?; + let circuit = GraphCircuit::from_settings(&settings, &model_path, CheckMode::UNSAFE)?; + circuit.save(compiled_circuit)?; + Ok(()) +} + +pub(crate) fn setup( + 
compiled_circuit: PathBuf, + srs_path: PathBuf, + vk_path: PathBuf, + pk_path: PathBuf, + witness: Option, +) -> Result<(), Box> { + // these aren't real values so the sanity checks are mostly meaningless + let mut circuit = GraphCircuit::load(compiled_circuit)?; + if let Some(witness) = witness { + let data = GraphWitness::from_path(witness)?; + circuit.load_graph_witness(&data)?; + } + + let params = load_params_cmd(srs_path, circuit.settings().run_args.logrows)?; + + let pk = create_keys::, Fr, GraphCircuit>(&circuit, ¶ms) + .map_err(Box::::from)?; + + save_vk::>(&vk_path, pk.get_vk())?; + save_pk::>(&pk_path, &pk)?; + Ok(()) +} + +#[cfg(not(target_arch = "wasm32"))] +pub(crate) async fn setup_test_evm_witness( + data_path: PathBuf, + compiled_circuit_path: PathBuf, + test_data: PathBuf, + rpc_url: Option, + input_source: TestDataSource, + output_source: TestDataSource, +) -> Result<(), Box> { + use crate::graph::TestOnChainData; + + info!("run this command in background to keep the instance running for testing"); + let mut data = GraphData::from_path(data_path)?; + let mut circuit = GraphCircuit::load(compiled_circuit_path)?; + + // if both input and output are from files fail + if matches!(input_source, TestDataSource::File) && matches!(output_source, TestDataSource::File) + { + return Err("Both input and output cannot be from files".into()); + } + + let test_on_chain_data = TestOnChainData { + data: test_data.clone(), + rpc: rpc_url, + data_sources: TestSources { + input: input_source, + output: output_source, + }, + }; + + circuit + .populate_on_chain_test_data(&mut data, test_on_chain_data) + .await?; + + Ok(()) +} + +#[cfg(not(target_arch = "wasm32"))] +use crate::pfsys::ProofType; +#[cfg(not(target_arch = "wasm32"))] +pub(crate) async fn test_update_account_calls( + addr: H160, + data: PathBuf, + rpc_url: Option, +) -> Result<(), Box> { + use crate::eth::update_account_calls; + + check_solc_requirement(); + update_account_calls(addr, data, 
rpc_url.as_deref()).await?; + + Ok(()) +} + +#[cfg(not(target_arch = "wasm32"))] +#[allow(clippy::too_many_arguments)] +pub(crate) fn prove( + data_path: PathBuf, + compiled_circuit_path: PathBuf, + pk_path: PathBuf, + proof_path: Option, + srs_path: PathBuf, + proof_type: ProofType, + check_mode: CheckMode, +) -> Result, Box> { + use crate::pfsys::ProofSplitCommit; + + let data = GraphWitness::from_path(data_path)?; + let mut circuit = GraphCircuit::load(compiled_circuit_path)?; + + circuit.load_graph_witness(&data)?; + + let public_inputs = circuit.prepare_public_inputs(&data)?; + + let circuit_settings = circuit.settings().clone(); + + let params = load_params_cmd(srs_path, circuit_settings.run_args.logrows)?; + + let pk = load_pk::, Fr, GraphCircuit>(pk_path, circuit_settings) + .map_err(Box::::from)?; + + trace!("params computed"); + + let strategy: StrategyType = proof_type.into(); + let transcript: TranscriptType = proof_type.into(); + let proof_split_commits: Option = data.into(); + + // creates and verifies the proof + let snark = match strategy { + StrategyType::Single => { + let strategy = KZGSingleStrategy::new(¶ms); + create_proof_circuit_kzg( + circuit, + ¶ms, + Some(public_inputs), + &pk, + transcript, + strategy, + check_mode, + proof_split_commits, + )? + } + StrategyType::Accum => { + let strategy = AccumulatorStrategy::new(¶ms); + create_proof_circuit_kzg( + circuit, + ¶ms, + Some(public_inputs), + &pk, + transcript, + strategy, + check_mode, + proof_split_commits, + )? 
+ } + }; + + if let Some(proof_path) = proof_path { + snark.save(&proof_path)?; + } + + Ok(snark) +} + +#[cfg(not(target_arch = "wasm32"))] +pub(crate) fn fuzz( + compiled_circuit_path: PathBuf, + data_path: PathBuf, + transcript: TranscriptType, + num_runs: usize, +) -> Result<(), Box> { + check_solc_requirement(); + let passed = AtomicBool::new(true); + + // these aren't real values so the sanity checks are mostly meaningless + let mut circuit = GraphCircuit::load(compiled_circuit_path)?; + let logrows = circuit.settings().run_args.logrows; + + info!("setting up tests"); + + let _r = Gag::stdout()?; + let params = gen_srs::>(logrows); + + let data = GraphWitness::from_path(data_path)?; + + let pk = create_keys::, Fr, GraphCircuit>(&circuit, ¶ms) + .map_err(Box::::from)?; + + circuit.load_graph_witness(&data)?; + + let public_inputs = circuit.prepare_public_inputs(&data)?; + + let strategy = KZGSingleStrategy::new(¶ms); + std::mem::drop(_r); + + info!("starting fuzzing"); + + info!("fuzzing pk"); + + let fuzz_pk = || { + let new_params = gen_srs::>(logrows); + + let bad_pk = + create_keys::, Fr, GraphCircuit>(&circuit, &new_params) + .map_err(|_| ())?; + + let bad_proof = create_proof_circuit_kzg( + circuit.clone(), + ¶ms, + Some(public_inputs.clone()), + &bad_pk, + transcript, + strategy.clone(), + CheckMode::UNSAFE, + None, + ) + .map_err(|_| ())?; + + verify_proof_circuit_kzg( + params.verifier_params(), + bad_proof, + pk.get_vk(), + strategy.clone(), + ) + .map_err(|_| ()) + }; + + run_fuzz_fn(num_runs, fuzz_pk, &passed); + + info!("fuzzing public inputs"); + + let fuzz_public_inputs = || { + let bad_inputs: Vec = (0..public_inputs.len()) + .map(|_| Fr::random(rand::rngs::OsRng)) + .collect(); + + let bad_proof = create_proof_circuit_kzg( + circuit.clone(), + ¶ms, + Some(bad_inputs.clone()), + &pk, + transcript, + strategy.clone(), + CheckMode::UNSAFE, + None, + ) + .map_err(|_| ())?; + + verify_proof_circuit_kzg( + params.verifier_params(), + bad_proof, + 
pk.get_vk(), + strategy.clone(), + ) + .map_err(|_| ()) + }; + + run_fuzz_fn(num_runs, fuzz_public_inputs, &passed); + + info!("fuzzing vk"); + + let proof = create_proof_circuit_kzg( + circuit.clone(), + ¶ms, + Some(public_inputs.clone()), + &pk, + transcript, + strategy.clone(), + CheckMode::SAFE, + None, + )?; + + let fuzz_vk = || { + let new_params = gen_srs::>(logrows); + + let bad_pk = + create_keys::, Fr, GraphCircuit>(&circuit, &new_params) + .map_err(|_| ())?; + + let bad_vk = bad_pk.get_vk(); + + verify_proof_circuit_kzg( + params.verifier_params(), + proof.clone(), + bad_vk, + strategy.clone(), + ) + .map_err(|_| ()) + }; + + run_fuzz_fn(num_runs, fuzz_vk, &passed); + + info!("fuzzing proof bytes"); + + let fuzz_proof_bytes = || { + let mut rng = rand::thread_rng(); + + let bad_proof_bytes: Vec = (0..proof.proof.len()) + .map(|_| rng.gen_range(0..20)) + .collect(); + + let bad_proof = Snark::<_, _> { + instances: proof.instances.clone(), + proof: bad_proof_bytes, + protocol: proof.protocol.clone(), + transcript_type: transcript, + split: None, + }; + + verify_proof_circuit_kzg( + params.verifier_params(), + bad_proof, + pk.get_vk(), + strategy.clone(), + ) + .map_err(|_| ()) + }; + + run_fuzz_fn(num_runs, fuzz_proof_bytes, &passed); + + info!("fuzzing proof instances"); + + let fuzz_proof_instances = || { + let mut bad_inputs = vec![vec![]]; + + for l in &proof.instances { + bad_inputs.push( + (0..l.len()) + .map(|_| Fr::random(rand::rngs::OsRng)) + .collect(), + ); + } + + let bad_proof = Snark::<_, _> { + instances: bad_inputs.clone(), + proof: proof.proof.clone(), + protocol: proof.protocol.clone(), + transcript_type: transcript, + split: None, + }; + + verify_proof_circuit_kzg( + params.verifier_params(), + bad_proof, + pk.get_vk(), + strategy.clone(), + ) + .map_err(|_| ()) + }; + + run_fuzz_fn(num_runs, fuzz_proof_instances, &passed); + + if !passed.into_inner() { + Err("fuzzing failed".into()) + } else { + Ok(()) + } +} + +#[cfg(not(target_arch = 
"wasm32"))] +pub(crate) fn run_fuzz_fn( + num_runs: usize, + f: impl Fn() -> Result<(), ()> + std::marker::Sync + std::marker::Send, + passed: &AtomicBool, +) { + let num_failures = AtomicI64::new(0); + let _r = Gag::stdout().unwrap(); + + let pb = init_bar(num_runs as u64); + pb.set_message("fuzzing..."); + (0..num_runs).into_par_iter().for_each(|_| { + let result = f(); + if result.is_ok() { + passed.swap(false, Ordering::Relaxed); + num_failures.fetch_add(1, Ordering::Relaxed); + } + pb.inc(1); + }); + pb.finish_with_message("Done."); + std::mem::drop(_r); + info!( + "num failures: {} out of {}", + num_failures.load(Ordering::Relaxed), + num_runs + ); +} + +pub(crate) fn swap_proof_commitments( + proof_path: PathBuf, + witness: PathBuf, +) -> Result<(), Box> { + let snark = Snark::load::>(&proof_path)?; + let witness = GraphWitness::from_path(witness)?; + let commitments = witness.get_kzg_commitments(); + + if commitments.is_empty() { + log::warn!("no commitments found in witness"); + } + + let snark_new = swap_proof_commitments_kzg(&snark, &commitments)?; + + if snark_new.proof != *snark.proof { + log::warn!("swap proof has created a different proof"); + } + + snark_new.save(&proof_path)?; + Ok(()) +} + +pub(crate) fn mock_aggregate( + aggregation_snarks: Vec, + logrows: u32, + split_proofs: bool, +) -> Result<(), Box> { + let mut snarks = vec![]; + for proof_path in aggregation_snarks.iter() { + snarks.push(Snark::load::>(proof_path)?); + } + // proof aggregation + #[cfg(not(target_arch = "wasm32"))] + let pb = { + let pb = init_spinner(); + pb.set_message("Aggregating (may take a while)..."); + pb + }; + + let circuit = AggregationCircuit::new(&G1Affine::generator().into(), snarks, split_proofs)?; + + let prover = halo2_proofs::dev::MockProver::run(logrows, &circuit, vec![circuit.instances()]) + .map_err(Box::::from)?; + prover + .verify_par() + .map_err(|e| Box::::from(ExecutionError::VerifyError(e)))?; + #[cfg(not(target_arch = "wasm32"))] + 
pb.finish_with_message("Done."); + Ok(()) +} + +pub(crate) fn setup_aggregate( + sample_snarks: Vec, + vk_path: PathBuf, + pk_path: PathBuf, + srs_path: PathBuf, + logrows: u32, + split_proofs: bool, +) -> Result<(), Box> { + // the K used for the aggregation circuit + let params = load_params_cmd(srs_path, logrows)?; + + let mut snarks = vec![]; + for proof_path in sample_snarks.iter() { + snarks.push(Snark::load::>(proof_path)?); + } + + let agg_circuit = AggregationCircuit::new(¶ms.get_g()[0].into(), snarks, split_proofs)?; + let agg_pk = + create_keys::, Fr, AggregationCircuit>(&agg_circuit, ¶ms)?; + + let agg_vk = agg_pk.get_vk(); + + // now save + save_vk::>(&vk_path, agg_vk)?; + save_pk::>(&pk_path, &agg_pk)?; + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub(crate) fn aggregate( + proof_path: PathBuf, + aggregation_snarks: Vec, + pk_path: PathBuf, + srs_path: PathBuf, + transcript: TranscriptType, + logrows: u32, + check_mode: CheckMode, + split_proofs: bool, +) -> Result<(), Box> { + // the K used for the aggregation circuit + let params = load_params_cmd(srs_path, logrows)?; + + let mut snarks = vec![]; + for proof_path in aggregation_snarks.iter() { + snarks.push(Snark::load::>(proof_path)?); + } + + let agg_pk = load_pk::, Fr, AggregationCircuit>(pk_path, ())?; + // proof aggregation + #[cfg(not(target_arch = "wasm32"))] + let pb = { + let pb = init_spinner(); + pb.set_message("Aggregating (may take a while)..."); + pb + }; + + { + let agg_circuit = AggregationCircuit::new(¶ms.get_g()[0].into(), snarks, split_proofs)?; + + let now = Instant::now(); + let snark = create_proof_circuit_kzg( + agg_circuit.clone(), + ¶ms, + Some(agg_circuit.instances()), + &agg_pk, + transcript, + AccumulatorStrategy::new(¶ms), + check_mode, + None, + )?; + + let elapsed = now.elapsed(); + info!( + "Aggregation proof took {}.{}", + elapsed.as_secs(), + elapsed.subsec_millis() + ); + snark.save(&proof_path)?; + } + #[cfg(not(target_arch = "wasm32"))] + 
pb.finish_with_message("Done."); + + Ok(()) +} + +pub(crate) fn verify( + proof_path: PathBuf, + settings_path: PathBuf, + vk_path: PathBuf, + srs_path: PathBuf, +) -> Result<(), Box> { + let circuit_settings = GraphSettings::load(&settings_path)?; + let params = load_params_cmd(srs_path, circuit_settings.run_args.logrows)?; + let proof = Snark::load::>(&proof_path)?; + + let strategy = KZGSingleStrategy::new(params.verifier_params()); + let vk = load_vk::, Fr, GraphCircuit>(vk_path, circuit_settings)?; + let now = Instant::now(); + let result = verify_proof_circuit_kzg(params.verifier_params(), proof, &vk, strategy); + let elapsed = now.elapsed(); + info!( + "verify took {}.{}", + elapsed.as_secs(), + elapsed.subsec_millis() + ); + info!("verified: {}", result.is_ok()); + result.map_err(|e| e.into()) +} + +pub(crate) fn verify_aggr( + proof_path: PathBuf, + vk_path: PathBuf, + srs_path: PathBuf, + logrows: u32, +) -> Result<(), Box> { + let params = load_params_cmd(srs_path, logrows)?; + + let proof = Snark::load::>(&proof_path)?; + + let strategy = AccumulatorStrategy::new(params.verifier_params()); + let vk = load_vk::, Fr, AggregationCircuit>(vk_path, ())?; + let now = Instant::now(); + let result = verify_proof_circuit_kzg(¶ms, proof, &vk, strategy); + + let elapsed = now.elapsed(); + info!( + "verify took {}.{}", + elapsed.as_secs(), + elapsed.subsec_millis() + ); + info!("verified: {}", result.is_ok()); + result?; + Ok(()) +} + +/// helper function to handle graphql errors +async fn parse_response( + response: reqwest::Response, +) -> Result> { + // Check if the response status is success + if !response.status().is_success() { + let status = response.status(); + let error_message = format!("Request failed with status code: {}", status); + return Err(Box::new(std::io::Error::new( + std::io::ErrorKind::Other, + error_message, + ))); + } + + let response_body = response.json::().await?; + + // Check if 'data' is null and 'errors' are present + if 
response_body.get("data").is_none() || response_body.get("data").unwrap().is_null() { + if let Some(errors) = response_body.get("errors") { + let error_messages: Vec = errors + .as_array() + .unwrap() + .iter() + .map(|error| error["message"].as_str().unwrap_or_default().to_string()) + .collect(); + + let custom_error_message = format!("An error occurred: {}", error_messages.join(", ")); + return Err(Box::new(std::io::Error::new( + std::io::ErrorKind::Other, + custom_error_message, + ))); + } else { + let error_message = + "An error occurred: Response contains null data but no error details."; + return Err(Box::new(std::io::Error::new( + std::io::ErrorKind::Other, + error_message, + ))); + } + } + + Ok(response_body) +} + +/// Retrieves the user's credentials from the hub +pub(crate) async fn get_hub_credentials( + api_key: Option<&str>, + url: Option<&str>, + username: &str, +) -> Result> { + let client = reqwest::Client::new(); + let request_body = serde_json::json!({ + "query": r#" + query GetOrganizationId($username: String!) { + organizations(name: $username) { + id + name + } + } + "#, + "variables": { + "username": username, + } + }); + let url = url.unwrap_or("https://hub-staging.ezkl.xyz/graphql"); + let api_key = api_key.unwrap_or("ed896983-2ec3-4aaf-afa7-f01299f3d61f"); + + let response = client + .post(url) + .header("API-Key", format!("{}", api_key)) + .json(&request_body) + .send() + .await?; + + // Using the parse_response helper function + let response_body = parse_response(response).await?; + + // Extracting the organizations data + let organizations: crate::hub::Organizations = + serde_json::from_value(response_body["data"].clone())?; + + log::info!( + "Organization ID : {}", + organizations.as_json()?.to_colored_json_auto()? 
+ ); + Ok(organizations) +} + +/// Deploy a model +pub(crate) async fn deploy_model( + api_key: Option<&str>, + url: Option<&str>, + model: &Path, + input: &Path, + name: &str, + organization_id: &str, + args: &RunArgs, + target: &CalibrationTarget, +) -> Result> { + let model_file = tokio::fs::File::open(model.canonicalize()?).await?; + // read file body stream + let stream = FramedRead::new(model_file, BytesCodec::new()); + let model_file_body = reqwest::Body::wrap_stream(stream); + + let model_file = reqwest::multipart::Part::stream(model_file_body).file_name("uncompiledModel"); + + let input_file = tokio::fs::File::open(input.canonicalize()?).await?; + // read file body stream + let stream = FramedRead::new(input_file, BytesCodec::new()); + let input_file_body = reqwest::Body::wrap_stream(stream); + + //make form part of file + let input_file = reqwest::multipart::Part::stream(input_file_body).file_name("input"); + + // the graphql request map + let map = r#"{ + "uncompiledModel": [ + "variables.uncompiledModel" + ], + "input": [ + "variables.input" + ] + }"#; + + let operations = serde_json::json!({ + "query": "mutation($uncompiledModel: Upload!, $input: Upload!, $organizationId: String!, $name: String!, $calibrationTarget: String!, $tolerance: Float!, $inputVisibility: String!, $outputVisibility: String!, $paramVisibility: String!) 
{ + generateArtifact( + name: $name, + description: $name, + uncompiledModel: $uncompiledModel, + input: $input, + organizationId: $organizationId, + calibrationTarget: $calibrationTarget, + tolerance: $tolerance, + inputVisibility: $inputVisibility, + outputVisibility: $outputVisibility, + paramVisibility: $paramVisibility, + ) { + id + name + status + errors + } + }", + "variables": { + "name": name, + "uncompiledModel": null, + "input": null, + "organizationId": organization_id, + "calibrationTarget": target.to_string(), + "tolerance": args.tolerance.val, + "inputVisibility": args.input_visibility.to_string(), + "outputVisibility": args.output_visibility.to_string(), + "paramVisibility": args.param_visibility.to_string(), + } + }) + .to_string(); + + // now the form data + let mut form = reqwest::multipart::Form::new(); + form = form + .text("operations", operations) + .text("map", map) + .part("uncompiledModel", model_file) + .part("input", input_file); + + let client = reqwest::Client::new(); + let url = url.unwrap_or("https://hub-staging.ezkl.xyz/graphql"); + let api_key = api_key.unwrap_or("ed896983-2ec3-4aaf-afa7-f01299f3d61f"); + //send request + let response = client + .post(url) + .header("API-Key", format!("{}", api_key)) + .multipart(form) + .send() + .await?; + let response_body = parse_response(response).await?; + let artifact_data: crate::hub::Artifact = + serde_json::from_value(response_body["data"]["generateArtifact"].clone())?; + log::info!( + "Artifact Data : {}", + artifact_data.as_json()?.to_colored_json_auto()? 
+ ); + Ok(artifact_data) +} + +/// Get the artifact from the hub +pub(crate) async fn get_deployed_model( + api_key: Option<&str>, + url: Option<&str>, + id: &str, +) -> Result> { + let query = serde_json::json!({ + "query": "query getArtifact($id: String!){ + artifact(id: $id) { + id + name + status + errors + } + }", + "variables": { + "id": id, + } + }); + let client = reqwest::Client::new(); + let url = url.unwrap_or("https://hub-staging.ezkl.xyz/graphql"); + let api_key = api_key.unwrap_or("ed896983-2ec3-4aaf-afa7-f01299f3d61f"); + //send request + let response = client + .post(url) + .header("API-Key", format!("{}", api_key)) + .json(&query) + .send() + .await?; + let response_body = parse_response(response).await?; + let artifact_data: crate::hub::Artifact = + serde_json::from_value(response_body["data"]["artifact"].clone())?; + log::info!( + "Artifact Data : {}", + artifact_data.as_json()?.to_colored_json_auto()? + ); + Ok(artifact_data) +} + +/// Generates proofs on the hub +pub async fn prove_hub( + api_key: Option<&str>, + url: Option<&str>, + id: &str, + input: &Path, +) -> Result> { + let input_file = tokio::fs::File::open(input.canonicalize()?).await?; + let stream = FramedRead::new(input_file, BytesCodec::new()); + let input_file_body = reqwest::Body::wrap_stream(stream); + + let input_file = reqwest::multipart::Part::stream(input_file_body).file_name("input"); + + let map = r#"{ + "input": [ + "variables.input" + ] + }"#; + + let operations = serde_json::json!({ + "query": r#" + mutation($input: Upload!, $id: String!) 
{ + initiateProof(input: $input, id: $id) { + id + } + } + "#, + "variables": { + "input": null, + "id": id, + } + }) + .to_string(); + + let mut form = reqwest::multipart::Form::new(); + form = form + .text("operations", operations) + .text("map", map) + .part("input", input_file); + let url = url.unwrap_or("https://hub-staging.ezkl.xyz/graphql"); + let api_key = api_key.unwrap_or("ed896983-2ec3-4aaf-afa7-f01299f3d61f"); + let client = reqwest::Client::new(); + let response = client + .post(url) + .header("API-Key", format!("{}", api_key)) + .multipart(form) + .send() + .await?; + + let response_body = parse_response(response).await?; + + // Check if 'data' is null and 'errors' are present + if response_body.get("data").is_none() || response_body.get("data").unwrap().is_null() { + if let Some(errors) = response_body.get("errors") { + let error_messages: Vec = errors + .as_array() + .unwrap() + .iter() + .map(|error| error["message"].as_str().unwrap_or_default().to_string()) + .collect(); + + let custom_error_message = format!("An error occurred: {}", error_messages.join(", ")); + log::error!("{}", custom_error_message); + return Err(Box::new(std::io::Error::new( + std::io::ErrorKind::Other, + custom_error_message, + ))); + } else { + let error_message = + "An error occurred: Response contains null data but no error details."; + log::error!("{}", error_message); + return Err(Box::new(std::io::Error::new( + std::io::ErrorKind::Other, + error_message, + ))); + } + } + let proof_id: crate::hub::Proof = + serde_json::from_value(response_body["data"]["initiateProof"].clone())?; + log::info!("Proof ID : {}", proof_id.as_json()?.to_colored_json_auto()?); + Ok(proof_id) +} + +/// Fetches proofs from the hub +pub(crate) async fn get_hub_proof( + api_key: Option<&str>, + url: Option<&str>, + id: &str, +) -> Result> { + let client = reqwest::Client::new(); + let request_body = serde_json::json!({ + "query": format!(r#" + query {{ + getProof(id: "{}") {{ + id + artifact {{ id 
name }} + status + proof + instances + transcriptType + }} + }} + "#, id), + }); + let url = url.unwrap_or("https://hub-staging.ezkl.xyz/graphql"); + let api_key = api_key.unwrap_or("ed896983-2ec3-4aaf-afa7-f01299f3d61f"); + + let response = client + .post(url) + .header("API-Key", format!("{:?}", api_key)) + .json(&request_body) + .send() + .await?; + let response_body = parse_response(response).await?; + + let proof: crate::hub::Proof = + serde_json::from_value(response_body["data"]["getProof"].clone())?; + + log::info!("Proof : {}", proof.as_json()?.to_colored_json_auto()?); + Ok(proof) +} + +/// helper function for load_params +pub(crate) fn load_params_cmd( + srs_path: PathBuf, + logrows: u32, +) -> Result, Box> { + let mut params: ParamsKZG = load_srs::>(srs_path)?; + info!("downsizing params to {} logrows", logrows); + if logrows < params.k() { + params.downsize(logrows); + } + Ok(params) +} diff --git a/mnist_ezkl/src/fieldutils.rs b/mnist_ezkl/src/fieldutils.rs new file mode 100644 index 0000000..f031684 --- /dev/null +++ b/mnist_ezkl/src/fieldutils.rs @@ -0,0 +1,106 @@ +use halo2_proofs::arithmetic::Field; +/// Utilities for converting from Halo2 PrimeField types to integers (and vice-versa). +use halo2curves::ff::PrimeField; + +/// Converts an i32 to a PrimeField element. +pub fn i32_to_felt(x: i32) -> F { + if x >= 0 { + F::from(x as u64) + } else { + -F::from(x.unsigned_abs() as u64) + } +} + +/// Converts an i128 to a PrimeField element. +pub fn i128_to_felt(x: i128) -> F { + if x >= 0 { + F::from_u128(x as u128) + } else { + -F::from_u128((-x) as u128) + } +} + +/// Converts a PrimeField element to an i32. 
+pub fn felt_to_i32(x: F) -> i32 { + if x > F::from(i32::MAX as u64) { + let rep = (-x).to_repr(); + let negtmp: &[u8] = rep.as_ref(); + let lower_32 = u32::from_le_bytes(negtmp[..4].try_into().unwrap()); + -(lower_32 as i32) + } else { + let rep = (x).to_repr(); + let tmp: &[u8] = rep.as_ref(); + let lower_32 = u32::from_le_bytes(tmp[..4].try_into().unwrap()); + lower_32 as i32 + } +} + +/// Converts a PrimeField element to an i128. +pub fn felt_to_f64(x: F) -> f64 { + if x > F::from_u128(i128::MAX as u128) { + let rep = (-x).to_repr(); + let negtmp: &[u8] = rep.as_ref(); + let lower_128: u128 = u128::from_le_bytes(negtmp[..16].try_into().unwrap()); + -(lower_128 as f64) + } else { + let rep = (x).to_repr(); + let tmp: &[u8] = rep.as_ref(); + let lower_128: u128 = u128::from_le_bytes(tmp[..16].try_into().unwrap()); + lower_128 as f64 + } +} + +/// Converts a PrimeField element to an i128. +pub fn felt_to_i128(x: F) -> i128 { + if x > F::from_u128(i128::MAX as u128) { + let rep = (-x).to_repr(); + let negtmp: &[u8] = rep.as_ref(); + let lower_128: u128 = u128::from_le_bytes(negtmp[..16].try_into().unwrap()); + -(lower_128 as i128) + } else { + let rep = (x).to_repr(); + let tmp: &[u8] = rep.as_ref(); + let lower_128: u128 = u128::from_le_bytes(tmp[..16].try_into().unwrap()); + lower_128 as i128 + } +} + +#[cfg(test)] +mod test { + + use super::*; + use halo2curves::pasta::Fp as F; + + #[test] + fn test_conv() { + let res: F = i32_to_felt(-15i32); + assert_eq!(res, -F::from(15)); + + let res: F = i32_to_felt(2_i32.pow(17)); + assert_eq!(res, F::from(131072)); + + let res: F = i128_to_felt(-15i128); + assert_eq!(res, -F::from(15)); + + let res: F = i128_to_felt(2_i128.pow(17)); + assert_eq!(res, F::from(131072)); + } + + #[test] + fn felttoi32() { + for x in -(2i32.pow(16))..(2i32.pow(16)) { + let fieldx: F = i32_to_felt::(x); + let xf: i32 = felt_to_i32::(fieldx); + assert_eq!(x, xf); + } + } + + #[test] + fn felttoi128() { + for x in 
-(2i128.pow(20))..(2i128.pow(20)) { + let fieldx: F = i128_to_felt::(x); + let xf: i128 = felt_to_i128::(fieldx); + assert_eq!(x, xf); + } + } +} diff --git a/mnist_ezkl/src/graph/input.rs b/mnist_ezkl/src/graph/input.rs new file mode 100644 index 0000000..159a374 --- /dev/null +++ b/mnist_ezkl/src/graph/input.rs @@ -0,0 +1,662 @@ +use crate::circuit::InputType; +use crate::fieldutils::i128_to_felt; +#[cfg(not(target_arch = "wasm32"))] +use crate::tensor::Tensor; +use halo2curves::bn256::Fr as Fp; +#[cfg(not(target_arch = "wasm32"))] +use postgres::{Client, NoTls}; +#[cfg(feature = "python-bindings")] +use pyo3::prelude::*; +#[cfg(feature = "python-bindings")] +use pyo3::types::PyDict; +#[cfg(feature = "python-bindings")] +use pyo3::ToPyObject; +use serde::ser::SerializeStruct; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use std::io::Read; +use std::panic::UnwindSafe; +#[cfg(not(target_arch = "wasm32"))] +use std::thread; +#[cfg(not(target_arch = "wasm32"))] +use tract_onnx::tract_hir::tract_num_traits::ToPrimitive; + +use super::quantize_float; +use super::GraphError; + +type Decimals = u8; +type Call = String; +type RPCUrl = String; + +/// +#[derive(Clone, Debug, PartialOrd, PartialEq)] +pub enum FileSourceInner { + /// Inner elements of float inputs coming from a file + Float(f64), + /// Inner elements of bool inputs coming from a file + Bool(bool), + /// Inner elements of inputs coming from a witness + Field(Fp), +} + +impl FileSourceInner { + /// + pub fn is_float(&self) -> bool { + matches!(self, FileSourceInner::Float(_)) + } + /// + pub fn is_bool(&self) -> bool { + matches!(self, FileSourceInner::Bool(_)) + } + /// + pub fn is_field(&self) -> bool { + matches!(self, FileSourceInner::Field(_)) + } +} + +impl Serialize for FileSourceInner { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + match self { + FileSourceInner::Field(data) => data.serialize(serializer), + FileSourceInner::Bool(data) => 
data.serialize(serializer), + FileSourceInner::Float(data) => data.serialize(serializer), + } + } +} + +// !!! ALWAYS USE JSON SERIALIZATION FOR GRAPH INPUT +// UNTAGGED ENUMS WONT WORK :( as highlighted here: +impl<'de> Deserialize<'de> for FileSourceInner { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let this_json: Box = Deserialize::deserialize(deserializer)?; + + let bool_try: Result = serde_json::from_str(this_json.get()); + if let Ok(t) = bool_try { + return Ok(FileSourceInner::Bool(t)); + } + let float_try: Result = serde_json::from_str(this_json.get()); + if let Ok(t) = float_try { + return Ok(FileSourceInner::Float(t)); + } + let field_try: Result = serde_json::from_str(this_json.get()); + if let Ok(t) = field_try { + return Ok(FileSourceInner::Field(t)); + } + + Err(serde::de::Error::custom( + "failed to deserialize FileSourceInner", + )) + } +} + +/// Elements of inputs coming from a file +pub type FileSource = Vec>; + +impl FileSourceInner { + /// Create a new FileSourceInner + pub fn new_float(f: f64) -> Self { + FileSourceInner::Float(f) + } + /// Create a new FileSourceInner + pub fn new_field(f: Fp) -> Self { + FileSourceInner::Field(f) + } + /// Create a new FileSourceInner + pub fn new_bool(f: bool) -> Self { + FileSourceInner::Bool(f) + } + + /// + pub fn as_type(&mut self, input_type: &InputType) { + match self { + FileSourceInner::Float(f) => input_type.roundtrip(f), + FileSourceInner::Bool(_) => assert!(matches!(input_type, InputType::Bool)), + FileSourceInner::Field(_) => {} + } + } + + /// Convert to a field element + pub fn to_field(&self, scale: crate::Scale) -> Fp { + match self { + FileSourceInner::Float(f) => i128_to_felt(quantize_float(f, 0.0, scale).unwrap()), + FileSourceInner::Bool(f) => { + if *f { + Fp::one() + } else { + Fp::zero() + } + } + FileSourceInner::Field(f) => *f, + } + } + /// Convert to a float + pub fn to_float(&self) -> f64 { + match self { + FileSourceInner::Float(f) => *f, + 
FileSourceInner::Bool(f) => { + if *f { + 1.0 + } else { + 0.0 + } + } + FileSourceInner::Field(f) => crate::fieldutils::felt_to_i128(*f) as f64, + } + } +} + +/// Inner elements of inputs/outputs coming from on-chain +#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)] +pub struct OnChainSource { + /// Vector of calls to accounts + pub calls: Vec, + /// RPC url + pub rpc: RPCUrl, +} + +impl OnChainSource { + /// Create a new OnChainSource + pub fn new(calls: Vec, rpc: RPCUrl) -> Self { + OnChainSource { calls, rpc } + } +} + +#[cfg(not(target_arch = "wasm32"))] +/// Inner elements of inputs/outputs coming from postgres DB +#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)] +pub struct PostgresSource { + /// postgres host + pub host: RPCUrl, + /// user to connect to postgres + pub user: String, + /// password to connect to postgres + pub password: String, + /// query to execute + pub query: String, + /// dbname + pub dbname: String, + /// port + pub port: String, +} + +#[cfg(not(target_arch = "wasm32"))] +impl PostgresSource { + /// Create a new PostgresSource + pub fn new( + host: RPCUrl, + port: String, + user: String, + query: String, + dbname: String, + password: String, + ) -> Self { + PostgresSource { + host, + user, + password, + query, + dbname, + port, + } + } + + /// Fetch data from postgres + pub fn fetch(&self) -> Result>, Box> { + // clone to move into thread + let user = self.user.clone(); + let host = self.host.clone(); + let query = self.query.clone(); + let dbname = self.dbname.clone(); + let port = self.port.clone(); + let password = self.password.clone(); + + let config = if password.is_empty() { + format!( + "host={} user={} dbname={} port={}", + host, user, dbname, port + ) + } else { + format!( + "host={} user={} dbname={} port={} password={}", + host, user, dbname, port, password + ) + }; + + let res: Vec = thread::spawn(move || { + let mut client = Client::connect(&config, 
NoTls).unwrap(); + let mut res: Vec = Vec::new(); + // extract rows from query + for row in client.query(&query, &[]).unwrap() { + // extract features from row + for i in 0..row.len() { + res.push(row.get(i)); + } + } + res + }) + .join() + .map_err(|_| "failed to fetch data from postgres")?; + + Ok(vec![res]) + } + + /// Fetch data from postgres and format it as a FileSource + pub fn fetch_and_format_as_file( + &self, + ) -> Result>, Box> { + Ok(self + .fetch()? + .iter() + .map(|d| { + d.iter() + .map(|d| { + FileSourceInner::Float( + d.n.as_ref() + .unwrap() + .to_f64() + .ok_or("could not convert decimal to f64") + .unwrap(), + ) + }) + .collect() + }) + .collect()) + } +} + +impl OnChainSource { + #[cfg(not(target_arch = "wasm32"))] + /// Create dummy local on-chain data to test the OnChain data source + pub async fn test_from_file_data( + data: &FileSource, + scales: Vec, + mut shapes: Vec>, + rpc: Option<&str>, + ) -> Result<(Vec>, Self), Box> { + use crate::eth::{evm_quantize, read_on_chain_inputs, test_on_chain_data}; + use log::debug; + + // Set up local anvil instance for reading on-chain data + let (anvil, client) = crate::eth::setup_eth_backend(rpc, None).await?; + + let address = client.address(); + + let mut scales = scales; + // set scales to 1 where data is a field element + for (idx, i) in data.iter().enumerate() { + if i.iter().all(|e| e.is_field()) { + scales[idx] = 0; + shapes[idx] = vec![i.len()]; + } + } + + let calls_to_accounts = test_on_chain_data(client.clone(), data).await?; + debug!("Calls to accounts: {:?}", calls_to_accounts); + let inputs = read_on_chain_inputs(client.clone(), address, &calls_to_accounts).await?; + debug!("Inputs: {:?}", inputs); + + let mut quantized_evm_inputs = vec![]; + + let mut prev = 0; + for (idx, i) in data.iter().enumerate() { + quantized_evm_inputs.extend( + evm_quantize( + client.clone(), + vec![scales[idx]; i.len()], + &( + inputs.0[prev..i.len()].to_vec(), + inputs.1[prev..i.len()].to_vec(), + ), + ) + 
.await?, + ); + prev += i.len(); + } + + // on-chain data has already been quantized at this point. Just need to reshape it and push into tensor vector + let mut inputs: Vec> = vec![]; + for (input, shape) in [quantized_evm_inputs].iter().zip(shapes) { + let mut t: Tensor = input.iter().cloned().collect(); + t.reshape(&shape)?; + inputs.push(t); + } + + let used_rpc = rpc.unwrap_or(&anvil.endpoint()).to_string(); + + // Fill the input_data field of the GraphData struct + Ok(( + inputs, + OnChainSource::new(calls_to_accounts.clone(), used_rpc), + )) + } +} + +/// Defines the view only calls to accounts to fetch the on-chain input data. +/// This data will be included as part of the first elements in the publicInputs +/// for the sol evm verifier and will be verifyWithDataAttestation.sol +#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)] +pub struct CallsToAccount { + /// A vector of tuples, where index 0 of tuples + /// are the byte strings representing the ABI encoded function calls to + /// read the data from the address. This call must return a single + /// elementary type (). + /// The second index of the tuple is the number of decimals for f32 conversion. + /// We don't support dynamic types currently. + pub call_data: Vec<(Call, Decimals)>, + /// Address of the contract to read the data from. + pub address: String, +} +/// Enum that defines source of the inputs/outputs to the EZKL model +#[derive(Clone, Debug, Serialize, PartialOrd, PartialEq)] +#[serde(untagged)] +pub enum DataSource { + /// .json File data source. + File(FileSource), + /// On-chain data source. The first element is the calls to the account, and the second is the RPC url. 
+ OnChain(OnChainSource), + /// Postgres DB + #[cfg(not(target_arch = "wasm32"))] + DB(PostgresSource), +} + +impl Default for DataSource { + fn default() -> Self { + DataSource::File(vec![vec![]]) + } +} + +impl From for DataSource { + fn from(data: FileSource) -> Self { + DataSource::File(data) + } +} + +impl From>> for DataSource { + fn from(data: Vec>) -> Self { + DataSource::File( + data.iter() + .map(|e| e.iter().map(|e| FileSourceInner::Field(*e)).collect()) + .collect(), + ) + } +} + +impl From>> for DataSource { + fn from(data: Vec>) -> Self { + DataSource::File( + data.iter() + .map(|e| e.iter().map(|e| FileSourceInner::Float(*e)).collect()) + .collect(), + ) + } +} + +impl From for DataSource { + fn from(data: OnChainSource) -> Self { + DataSource::OnChain(data) + } +} + +// !!! ALWAYS USE JSON SERIALIZATION FOR GRAPH INPUT +// UNTAGGED ENUMS WONT WORK :( as highlighted here: +impl<'de> Deserialize<'de> for DataSource { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let this_json: Box = Deserialize::deserialize(deserializer)?; + + let first_try: Result = serde_json::from_str(this_json.get()); + + if let Ok(t) = first_try { + return Ok(DataSource::File(t)); + } + let second_try: Result = serde_json::from_str(this_json.get()); + if let Ok(t) = second_try { + return Ok(DataSource::OnChain(t)); + } + #[cfg(not(target_arch = "wasm32"))] + { + let third_try: Result = serde_json::from_str(this_json.get()); + if let Ok(t) = third_try { + return Ok(DataSource::DB(t)); + } + } + + Err(serde::de::Error::custom("failed to deserialize DataSource")) + } +} + +/// Input to graph as a datasource +/// Always use JSON serialization for GraphData. Seriously. +#[derive(Clone, Debug, Deserialize, Default, PartialEq)] +pub struct GraphData { + /// Inputs to the model / computational graph (can be empty vectors if inputs are coming from on-chain). 
+ pub input_data: DataSource, + /// Outputs of the model / computational graph (can be empty vectors if outputs are coming from on-chain). + pub output_data: Option, +} + +impl UnwindSafe for GraphData {} + +impl GraphData { + /// + pub fn new(input_data: DataSource) -> Self { + GraphData { + input_data, + output_data: None, + } + } + + /// Load the model input from a file + pub fn from_path(path: std::path::PathBuf) -> Result> { + let mut file = std::fs::File::open(path.clone()) + .map_err(|_| format!("failed to open input at {}", path.display()))?; + let mut data = String::new(); + file.read_to_string(&mut data)?; + serde_json::from_str(&data).map_err(|e| e.into()) + } + + /// Save the model input to a file + pub fn save(&self, path: std::path::PathBuf) -> Result<(), Box> { + serde_json::to_writer(std::fs::File::create(path)?, &self).map_err(|e| e.into()) + } + + /// + pub fn split_into_batches( + &self, + input_shapes: Vec>, + ) -> Result, Box> { + // split input data into batches + let mut batched_inputs = vec![]; + + let iterable = match self { + GraphData { + input_data: DataSource::File(data), + output_data: _, + } => data.clone(), + GraphData { + input_data: DataSource::OnChain(_), + output_data: _, + } => { + return Err(Box::new(GraphError::InvalidDims( + 0, + "on-chain data cannot be split into batches".to_string(), + ))) + } + #[cfg(not(target_arch = "wasm32"))] + GraphData { + input_data: DataSource::DB(data), + output_data: _, + } => data.fetch_and_format_as_file()?, + }; + + for (i, shape) in input_shapes.iter().enumerate() { + // ensure the input is evenly divisible by batch_size + let input_size = shape.clone().iter().product::(); + let input = &iterable[i]; + if input.len() % input_size != 0 { + return Err(Box::new(GraphError::InvalidDims( + 0, + "calibration data length must be evenly divisible by the original input_size" + .to_string(), + ))); + } + let mut batches = vec![]; + for batch in input.chunks(input_size) { + 
batches.push(batch.to_vec()); + } + batched_inputs.push(batches); + } + + // now merge all the batches for each input into a vector of batches + // first assert each input has the same number of batches + let num_batches = if batched_inputs.is_empty() { + 0 + } else { + let num_batches = batched_inputs[0].len(); + for input in batched_inputs.iter() { + assert_eq!(input.len(), num_batches); + } + num_batches + }; + // now merge the batches + let mut input_batches = vec![]; + for i in 0..num_batches { + let mut batch = vec![]; + for input in batched_inputs.iter() { + batch.push(input[i].clone()); + } + input_batches.push(DataSource::File(batch)); + } + + if input_batches.is_empty() { + input_batches.push(DataSource::File(vec![vec![]])); + } + + // create a new GraphWitness for each batch + let batches = input_batches + .into_iter() + .map(GraphData::new) + .collect::>(); + + Ok(batches) + } +} + +#[cfg(feature = "python-bindings")] +impl ToPyObject for CallsToAccount { + fn to_object(&self, py: Python) -> PyObject { + let dict = PyDict::new(py); + dict.set_item("account", &self.address).unwrap(); + dict.set_item("call_data", &self.call_data).unwrap(); + dict.to_object(py) + } +} + +#[cfg(feature = "python-bindings")] +impl ToPyObject for DataSource { + fn to_object(&self, py: Python) -> PyObject { + match self { + DataSource::File(data) => data.to_object(py), + DataSource::OnChain(source) => { + let dict = PyDict::new(py); + dict.set_item("rpc_url", &source.rpc).unwrap(); + dict.set_item("calls_to_accounts", &source.calls).unwrap(); + dict.to_object(py) + } + DataSource::DB(source) => { + let dict = PyDict::new(py); + dict.set_item("host", &source.host).unwrap(); + dict.set_item("user", &source.user).unwrap(); + dict.set_item("query", &source.query).unwrap(); + dict.to_object(py) + } + } + } +} + +#[cfg(feature = "python-bindings")] +use crate::pfsys::field_to_vecu64_montgomery; + +#[cfg(feature = "python-bindings")] +impl ToPyObject for FileSourceInner { + fn 
to_object(&self, py: Python) -> PyObject { + match self { + FileSourceInner::Field(data) => field_to_vecu64_montgomery(data).to_object(py), + FileSourceInner::Bool(data) => data.to_object(py), + FileSourceInner::Float(data) => data.to_object(py), + } + } +} + +impl Serialize for GraphData { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut state = serializer.serialize_struct("GraphData", 4)?; + state.serialize_field("input_data", &self.input_data)?; + state.serialize_field("output_data", &self.output_data)?; + state.end() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + // this is for backwards compatibility with the old format + fn test_data_source_serialization_round_trip() { + let source = DataSource::from(vec![vec![0.053_262_424, 0.074_970_566, 0.052_355_476]]); + + let serialized = serde_json::to_string(&source).unwrap(); + + const JSON: &str = r#"[[0.053262424,0.074970566,0.052355476]]"#; + + assert_eq!(serialized, JSON); + + let expect = serde_json::from_str::(JSON) + .map_err(|e| e.to_string()) + .unwrap(); + + assert_eq!(expect, source); + } + + #[test] + // this is for backwards compatibility with the old format + fn test_graph_input_serialization_round_trip() { + let file = GraphData::new(DataSource::from(vec![vec![ + 0.05326242372393608, + 0.07497056573629379, + 0.05235547572374344, + ]])); + + let serialized = serde_json::to_string(&file).unwrap(); + + const JSON: &str = r#"{"input_data":[[0.05326242372393608,0.07497056573629379,0.05235547572374344]],"output_data":null}"#; + + assert_eq!(serialized, JSON); + + let graph_input3 = serde_json::from_str::(JSON) + .map_err(|e| e.to_string()) + .unwrap(); + assert_eq!(graph_input3, file); + } + + // test for the compatibility with the serialized elements from the mclbn256 library + #[test] + fn test_python_compat() { + let source = Fp::from_raw([18445520602771460712, 838677322461845011, 3079992810, 0]); + + let original_addr = 
"0x000000000000000000000000b794f5ea0ba39494ce839613fffba74279579268"; + + assert_eq!(format!("{:?}", source), original_addr); + } +} diff --git a/mnist_ezkl/src/graph/mod.rs b/mnist_ezkl/src/graph/mod.rs new file mode 100644 index 0000000..769a7e2 --- /dev/null +++ b/mnist_ezkl/src/graph/mod.rs @@ -0,0 +1,1542 @@ +/// Representations of a computational graph's inputs. +pub mod input; +/// Crate for defining a computational graph and building a ZK-circuit from it. +pub mod model; +/// Representations of a computational graph's modules. +pub mod modules; +/// Inner elements of a computational graph that represent a single operation / constraints. +pub mod node; +/// Helper functions +pub mod utilities; +/// Representations of a computational graph's variables. +pub mod vars; +#[cfg(not(target_arch = "wasm32"))] +use colored_json::ToColoredJson; +use halo2_proofs::plonk::VerifyingKey; +use halo2_proofs::poly::kzg::commitment::ParamsKZG; +pub use input::DataSource; +use itertools::Itertools; + +#[cfg(not(target_arch = "wasm32"))] +use self::input::OnChainSource; +use self::input::{FileSource, GraphData}; +use self::modules::{ + GraphModules, ModuleConfigs, ModuleForwardResult, ModuleSettings, ModuleSizes, +}; +use crate::circuit::lookup::LookupOp; +use crate::circuit::modules::ModulePlanner; +use crate::circuit::table::{Table, RESERVED_BLINDING_ROWS_PAD}; +use crate::circuit::{CheckMode, InputType}; +use crate::tensor::{Tensor, ValTensor}; +use crate::RunArgs; +use halo2_proofs::{ + circuit::Layouter, + plonk::{Circuit, ConstraintSystem, Error as PlonkError}, +}; +use halo2curves::bn256::{self, Bn256, Fr as Fp, G1Affine}; +use halo2curves::ff::PrimeField; +use log::{debug, error, info, trace, warn}; +pub use model::*; +pub use node::*; +#[cfg(feature = "python-bindings")] +use pyo3::prelude::*; +#[cfg(feature = "python-bindings")] +use pyo3::types::PyDict; +#[cfg(feature = "python-bindings")] +use pyo3::ToPyObject; +use rayon::prelude::{IntoParallelRefIterator, 
ParallelIterator}; +use serde::{Deserialize, Serialize}; +use std::io::{Read, Write}; +use std::ops::Deref; +use thiserror::Error; +pub use utilities::*; +pub use vars::*; + +#[cfg(feature = "python-bindings")] +use crate::pfsys::field_to_vecu64_montgomery; + +/// The safety factor for the range of the lookup table. +pub const RANGE_MULTIPLIER: i128 = 2; + +/// Max representation of a lookup table input +pub const MAX_LOOKUP_ABS: i128 = 8 * 2_i128.pow(MAX_PUBLIC_SRS); + +/// circuit related errors. +#[derive(Debug, Error)] +pub enum GraphError { + /// The wrong inputs were passed to a lookup node + #[error("invalid inputs for a lookup node")] + InvalidLookupInputs, + /// Shape mismatch in circuit construction + #[error("invalid dimensions used for node {0} ({1})")] + InvalidDims(usize, String), + /// Wrong method was called to configure an op + #[error("wrong method was called to configure node {0} ({1})")] + WrongMethod(usize, String), + /// A requested node is missing in the graph + #[error("a requested node is missing in the graph: {0}")] + MissingNode(usize), + /// The wrong method was called on an operation + #[error("an unsupported method was called on node {0} ({1})")] + OpMismatch(usize, String), + /// This operation is unsupported + #[error("unsupported operation in graph")] + UnsupportedOp, + /// This operation is unsupported + #[error("unsupported datatype in graph")] + UnsupportedDataType, + /// A node has missing parameters + #[error("a node is missing required params: {0}")] + MissingParams(String), + /// A node has missing parameters + #[error("a node is has misformed params: {0}")] + MisformedParams(String), + /// Error in the configuration of the visibility of variables + #[error("there should be at least one set of public variables")] + Visibility, + /// Ezkl only supports divisions by constants + #[error("ezkl currently only supports division by constants")] + NonConstantDiv, + /// Ezkl only supports constant powers + #[error("ezkl currently only 
supports constant exponents")] + NonConstantPower, + /// Error when attempting to rescale an operation + #[error("failed to rescale inputs for {0}")] + RescalingError(String), + /// Error when attempting to load a model + #[error("failed to load model")] + ModelLoad, + /// Packing exponent is too large + #[error("largest packing exponent exceeds max. try reducing the scale")] + PackingExponent, + /// Invalid Input Types + #[error("invalid input types")] + InvalidInputTypes, + /// Missing results + #[error("missing results")] + MissingResults, +} + +const ASSUMED_BLINDING_FACTORS: usize = 5; +/// The minimum number of rows in the grid +pub const MIN_LOGROWS: u32 = 6; + +/// 26 +pub const MAX_PUBLIC_SRS: u32 = bn256::Fr::S - 2; + +/// Lookup deg +pub const LOOKUP_DEG: usize = 5; + +use std::cell::RefCell; + +thread_local!( + /// This is a global variable that holds the settings for the graph + /// This is used to pass settings to the layouter and other parts of the circuit without needing to heavily modify the Halo2 API in a new fork + pub static GLOBAL_SETTINGS: RefCell> = RefCell::new(None) +); + +/// Result from a forward pass +#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq)] +pub struct GraphWitness { + /// The inputs of the forward pass + pub inputs: Vec>, + /// The output of the forward pass + pub outputs: Vec>, + /// Any hashes of inputs generated during the forward pass + pub processed_inputs: Option, + /// Any hashes of params generated during the forward pass + pub processed_params: Option, + /// Any hashes of outputs generated during the forward pass + pub processed_outputs: Option, + /// max lookup input + pub max_lookup_inputs: i128, + /// max lookup input + pub min_lookup_inputs: i128, +} + +impl GraphWitness { + /// + pub fn new(inputs: Vec>, outputs: Vec>) -> Self { + GraphWitness { + inputs, + outputs, + processed_inputs: None, + processed_params: None, + processed_outputs: None, + max_lookup_inputs: 0, + min_lookup_inputs: 0, 
+ } + } + + /// + pub fn get_kzg_commitments(&self) -> Vec { + let mut commitments = vec![]; + if let Some(processed_inputs) = &self.processed_inputs { + if let Some(commits) = &processed_inputs.kzg_commit { + commitments.extend(commits.iter().flatten()); + } + } + if let Some(processed_params) = &self.processed_params { + if let Some(commits) = &processed_params.kzg_commit { + commitments.extend(commits.iter().flatten()); + } + } + if let Some(processed_outputs) = &self.processed_outputs { + if let Some(commits) = &processed_outputs.kzg_commit { + commitments.extend(commits.iter().flatten()); + } + } + commitments + } + + /// Export the ezkl witness as json + pub fn as_json(&self) -> Result> { + let serialized = match serde_json::to_string(&self) { + Ok(s) => s, + Err(e) => { + return Err(Box::new(e)); + } + }; + Ok(serialized) + } + + /// Load the model input from a file + pub fn from_path(path: std::path::PathBuf) -> Result> { + let mut file = std::fs::File::open(path.clone()) + .map_err(|_| format!("failed to load model at {}", path.display()))?; + let mut data = String::new(); + file.read_to_string(&mut data)?; + serde_json::from_str(&data).map_err(|e| e.into()) + } + + /// Save the model input to a file + pub fn save(&self, path: std::path::PathBuf) -> Result<(), Box> { + serde_json::to_writer(std::fs::File::create(path)?, &self).map_err(|e| e.into()) + } + + /// + pub fn get_input_tensor(&self) -> Vec> { + self.inputs + .clone() + .into_iter() + .map(|i| Tensor::from(i.into_iter())) + .collect::>>() + } + + /// + pub fn get_output_tensor(&self) -> Vec> { + self.outputs + .clone() + .into_iter() + .map(|i| Tensor::from(i.into_iter())) + .collect::>>() + } +} + +#[cfg(feature = "python-bindings")] +impl ToPyObject for GraphWitness { + fn to_object(&self, py: Python) -> PyObject { + // Create a Python dictionary + let dict = PyDict::new(py); + let dict_inputs = PyDict::new(py); + let dict_params = PyDict::new(py); + let dict_outputs = PyDict::new(py); + + let 
inputs: Vec> = self + .inputs + .iter() + .map(|x| x.iter().map(field_to_vecu64_montgomery).collect()) + .collect(); + + let outputs: Vec> = self + .outputs + .iter() + .map(|x| x.iter().map(field_to_vecu64_montgomery).collect()) + .collect(); + + dict.set_item("inputs", &inputs).unwrap(); + dict.set_item("outputs", &outputs).unwrap(); + dict.set_item("max_lookup_inputs", &self.max_lookup_inputs) + .unwrap(); + + if let Some(processed_inputs) = &self.processed_inputs { + //poseidon_hash + if let Some(processed_inputs_poseidon_hash) = &processed_inputs.poseidon_hash { + insert_poseidon_hash_pydict(&dict_inputs, &processed_inputs_poseidon_hash).unwrap(); + } + if let Some(processed_inputs_elgamal) = &processed_inputs.elgamal { + insert_elgamal_results_pydict(py, dict_inputs, processed_inputs_elgamal).unwrap(); + } + if let Some(processed_inputs_kzg_commit) = &processed_inputs.kzg_commit { + insert_kzg_commit_pydict(&dict_inputs, &processed_inputs_kzg_commit).unwrap(); + } + + dict.set_item("processed_inputs", dict_inputs).unwrap(); + } + + if let Some(processed_params) = &self.processed_params { + if let Some(processed_params_poseidon_hash) = &processed_params.poseidon_hash { + insert_poseidon_hash_pydict(dict_params, &processed_params_poseidon_hash).unwrap(); + } + if let Some(processed_params_elgamal) = &processed_params.elgamal { + insert_elgamal_results_pydict(py, dict_params, processed_params_elgamal).unwrap(); + } + if let Some(processed_params_kzg_commit) = &processed_params.kzg_commit { + insert_kzg_commit_pydict(&dict_inputs, &processed_params_kzg_commit).unwrap(); + } + + dict.set_item("processed_params", dict_params).unwrap(); + } + + if let Some(processed_outputs) = &self.processed_outputs { + if let Some(processed_outputs_poseidon_hash) = &processed_outputs.poseidon_hash { + insert_poseidon_hash_pydict(dict_outputs, &processed_outputs_poseidon_hash) + .unwrap(); + } + if let Some(processed_outputs_elgamal) = &processed_outputs.elgamal { + 
insert_elgamal_results_pydict(py, dict_outputs, processed_outputs_elgamal).unwrap(); + } + if let Some(processed_outputs_kzg_commit) = &processed_outputs.kzg_commit { + insert_kzg_commit_pydict(&dict_inputs, &processed_outputs_kzg_commit).unwrap(); + } + + dict.set_item("processed_outputs", dict_outputs).unwrap(); + } + + dict.to_object(py) + } +} + +#[cfg(feature = "python-bindings")] +fn insert_poseidon_hash_pydict(pydict: &PyDict, poseidon_hash: &Vec) -> Result<(), PyErr> { + let poseidon_hash: Vec<[u64; 4]> = poseidon_hash + .iter() + .map(field_to_vecu64_montgomery) + .collect(); + pydict.set_item("poseidon_hash", poseidon_hash)?; + + Ok(()) +} + +#[cfg(feature = "python-bindings")] +fn insert_kzg_commit_pydict(pydict: &PyDict, commits: &Vec>) -> Result<(), PyErr> { + use crate::python::PyG1Affine; + let poseidon_hash: Vec> = commits + .iter() + .map(|c| c.iter().map(|x| PyG1Affine::from(*x)).collect()) + .collect(); + pydict.set_item("kzg_commit", poseidon_hash)?; + + Ok(()) +} + +#[cfg(feature = "python-bindings")] +use modules::ElGamalResult; +#[cfg(feature = "python-bindings")] +fn insert_elgamal_results_pydict( + py: Python, + pydict: &PyDict, + elgamal_results: &ElGamalResult, +) -> Result<(), PyErr> { + let results_dict = PyDict::new(py); + let cipher_text: Vec> = elgamal_results + .ciphertexts + .iter() + .map(|v| { + v.iter() + .map(field_to_vecu64_montgomery) + .collect::>() + }) + .collect::>>(); + results_dict.set_item("ciphertexts", cipher_text)?; + + let encrypted_messages: Vec> = elgamal_results + .encrypted_messages + .iter() + .map(|v| { + v.iter() + .map(field_to_vecu64_montgomery) + .collect::>() + }) + .collect::>>(); + results_dict.set_item("encrypted_messages", encrypted_messages)?; + + let variables: crate::python::PyElGamalVariables = elgamal_results.variables.clone().into(); + + results_dict.set_item("variables", variables)?; + + pydict.set_item("elgamal", results_dict)?; + + Ok(()) + + //elgamal +} + +/// model parameters 
+#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] +pub struct GraphSettings { + /// run args + pub run_args: RunArgs, + /// the potential number of rows used by the circuit + pub num_rows: usize, + /// total linear coordinate of assignments + pub total_assignments: usize, + /// total const size + pub total_const_size: usize, + /// the shape of public inputs to the model (in order of appearance) + pub model_instance_shapes: Vec>, + /// model output scales + pub model_output_scales: Vec, + /// model input scales + pub model_input_scales: Vec, + /// the of instance cells used by modules + pub module_sizes: ModuleSizes, + /// required_lookups + pub required_lookups: Vec, + /// check mode + pub check_mode: CheckMode, + /// ezkl version used + pub version: String, + /// num blinding factors + pub num_blinding_factors: Option, +} + +impl GraphSettings { + /// calculate the total number of instances + pub fn total_instances(&self) -> Vec { + let mut instances: Vec = self + .model_instance_shapes + .iter() + .map(|x| x.iter().product()) + .collect(); + instances.extend(self.module_sizes.num_instances()); + + instances + } + + /// save params to file + pub fn save(&self, path: &std::path::PathBuf) -> Result<(), std::io::Error> { + let encoded = serde_json::to_string(&self)?; + let mut file = std::fs::File::create(path)?; + file.write_all(encoded.as_bytes()) + } + /// load params from file + pub fn load(path: &std::path::PathBuf) -> Result { + let mut file = std::fs::File::open(path).map_err(|e| { + error!("failed to open settings file at {}", e); + e + })?; + let mut data = String::new(); + file.read_to_string(&mut data)?; + let res = serde_json::from_str(&data)?; + Ok(res) + } + + /// Export the ezkl configuration as json + pub fn as_json(&self) -> Result> { + let serialized = match serde_json::to_string(&self) { + Ok(s) => s, + Err(e) => { + return Err(Box::new(e)); + } + }; + Ok(serialized) + } + /// Parse an ezkl configuration from a json + pub fn 
from_json(arg_json: &str) -> Result { + serde_json::from_str(arg_json) + } + + fn set_num_blinding_factors(&mut self, num_blinding_factors: usize) { + self.num_blinding_factors = Some(num_blinding_factors); + } + + /// + pub fn available_col_size(&self) -> usize { + let base = 2u32; + if let Some(num_blinding_factors) = self.num_blinding_factors { + base.pow(self.run_args.logrows) as usize - num_blinding_factors - 1 + } else { + log::error!("num_blinding_factors not set"); + log::warn!("using default available_col_size"); + base.pow(self.run_args.logrows) as usize - ASSUMED_BLINDING_FACTORS - 1 + } + } + + /// + pub fn uses_modules(&self) -> bool { + !self.module_sizes.max_constraints() > 0 + } + + /// if any visibility is encrypted or hashed + pub fn module_requires_fixed(&self) -> bool { + if self.run_args.input_visibility.is_encrypted() + || self.run_args.input_visibility.is_hashed() + || self.run_args.output_visibility.is_encrypted() + || self.run_args.output_visibility.is_hashed() + || self.run_args.param_visibility.is_encrypted() + || self.run_args.param_visibility.is_hashed() + { + true + } else { + false + } + } +} + +/// Configuration for a computational graph / model loaded from a `.onnx` file. +#[derive(Clone, Debug)] +pub struct GraphConfig { + model_config: ModelConfig, + module_configs: ModuleConfigs, +} + +/// Defines the circuit for a computational graph / model loaded from a `.onnx` file. +#[derive(Clone, Debug, Default, Serialize, Deserialize)] +pub struct CoreCircuit { + /// The model / graph of computations. + pub model: Model, + /// The settings of the model. + pub settings: GraphSettings, +} + +/// Defines the circuit for a computational graph / model loaded from a `.onnx` file. +#[derive(Clone, Debug, Default, Serialize, Deserialize)] +pub struct GraphCircuit { + /// Core circuit + pub core: CoreCircuit, + /// The witness data for the model. + pub graph_witness: GraphWitness, + /// The settings of the model's modules. 
+ pub module_settings: ModuleSettings, +} + +impl GraphCircuit { + /// Settings for the graph + pub fn settings(&self) -> &GraphSettings { + &self.core.settings + } + /// Settings for the graph (mutable) + pub fn settings_mut(&mut self) -> &mut GraphSettings { + &mut self.core.settings + } + /// The model + pub fn model(&self) -> &Model { + &self.core.model + } + /// + pub fn save(&self, path: std::path::PathBuf) -> Result<(), Box> { + let f = std::fs::File::create(path)?; + let writer = std::io::BufWriter::new(f); + bincode::serialize_into(writer, &self)?; + Ok(()) + } + + /// + pub fn load(path: std::path::PathBuf) -> Result> { + // read bytes from file + let mut f = std::fs::File::open(&path)?; + let metadata = std::fs::metadata(&path)?; + let mut buffer = vec![0; metadata.len() as usize]; + f.read_exact(&mut buffer)?; + let result = bincode::deserialize(&buffer)?; + Ok(result) + } +} + +#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq, PartialOrd)] +/// The data source for a test +pub enum TestDataSource { + /// The data is loaded from a file + File, + /// The data is loaded from the chain + #[default] + OnChain, +} + +impl From for TestDataSource { + fn from(value: String) -> Self { + match value.to_lowercase().as_str() { + "file" => TestDataSource::File, + "on-chain" => TestDataSource::OnChain, + _ => { + error!("invalid data source: {}", value); + warn!("using default data source: on-chain"); + TestDataSource::default() + } + } + } +} + +#[derive(Clone, Debug, Default)] +/// +pub struct TestSources { + /// + pub input: TestDataSource, + /// + pub output: TestDataSource, +} + +/// +#[derive(Clone, Debug, Default)] +pub struct TestOnChainData { + /// The path to the test witness + pub data: std::path::PathBuf, + /// rpc endpoint + pub rpc: Option, + /// + pub data_sources: TestSources, +} + +impl GraphCircuit { + /// + pub fn new( + model: Model, + run_args: &RunArgs, + ) -> Result> { + // // placeholder dummy inputs - must call 
prepare_public_inputs to load data afterwards + let mut inputs: Vec> = vec![]; + for shape in model.graph.input_shapes()? { + let t: Vec = vec![Fp::zero(); shape.iter().product::()]; + inputs.push(t); + } + + // dummy module settings, must load from GraphData after + let module_settings = ModuleSettings::default(); + let mut settings = model.gen_params(run_args, CheckMode::UNSAFE)?; + + let mut num_params = 0; + if !model.const_shapes().is_empty() { + for shape in model.const_shapes() { + num_params += shape.iter().product::(); + } + } + + let sizes = GraphModules::num_constraints_and_instances( + model.graph.input_shapes()?, + vec![vec![num_params]], + model.graph.output_shapes()?, + VarVisibility::from_args(run_args)?, + ); + + // number of instances used by modules + settings.module_sizes = sizes.clone(); + + // as they occupy independent rows + settings.num_rows = std::cmp::max(settings.num_rows, sizes.max_constraints()); + + let core = CoreCircuit { + model, + settings: settings.clone(), + }; + + Ok(GraphCircuit { + core, + graph_witness: GraphWitness::new(inputs, vec![]), + module_settings, + }) + } + + /// + pub fn new_from_settings( + model: Model, + mut settings: GraphSettings, + check_mode: CheckMode, + ) -> Result> { + // placeholder dummy inputs - must call prepare_public_inputs to load data afterwards + let mut inputs: Vec> = vec![]; + for shape in model.graph.input_shapes()? 
{ + let t: Vec = vec![Fp::zero(); shape.iter().product::()]; + inputs.push(t); + } + + // dummy module settings, must load from GraphData after + let module_settings = ModuleSettings::default(); + + settings.check_mode = check_mode; + + let core = CoreCircuit { + model, + settings: settings.clone(), + }; + + Ok(GraphCircuit { + core, + graph_witness: GraphWitness::new(inputs, vec![]), + module_settings, + }) + } + + /// load inputs and outputs for the model + pub fn load_graph_witness( + &mut self, + data: &GraphWitness, + ) -> Result<(), Box> { + self.graph_witness = data.clone(); + // load the module settings + self.module_settings = ModuleSettings::from(data); + + Ok(()) + } + + /// Prepare the public inputs for the circuit. + pub fn prepare_public_inputs( + &mut self, + data: &GraphWitness, + ) -> Result, Box> { + // quantize the supplied data using the provided scale. + // the ordering here is important, we want the inputs to come before the outputs + // as they are configured in that order as Column + let mut public_inputs: Vec = vec![]; + if self.settings().run_args.input_visibility.is_public() { + public_inputs.extend(self.graph_witness.inputs.clone().into_iter().flatten()) + } else if let Some(processed_inputs) = &data.processed_inputs { + public_inputs.extend(processed_inputs.get_instances().into_iter().flatten()); + } + + if let Some(processed_params) = &data.processed_params { + public_inputs.extend(processed_params.get_instances().into_iter().flatten()); + } + + if self.settings().run_args.output_visibility.is_public() { + public_inputs.extend(self.graph_witness.outputs.clone().into_iter().flatten()); + } else if let Some(processed_outputs) = &data.processed_outputs { + public_inputs.extend(processed_outputs.get_instances().into_iter().flatten()); + } + + debug!("public inputs: {:?}", public_inputs); + + Ok(public_inputs) + } + + /// + #[cfg(target_arch = "wasm32")] + pub fn load_graph_input( + &mut self, + data: &GraphData, + ) -> Result>, Box> { + 
let shapes = self.model().graph.input_shapes()?; + let scales = self.model().graph.get_input_scales(); + let input_types = self.model().graph.get_input_types()?; + self.process_data_source(&data.input_data, shapes, scales, input_types) + } + + /// + pub fn load_graph_from_file_exclusively( + &mut self, + data: &GraphData, + ) -> Result>, Box> { + let shapes = self.model().graph.input_shapes()?; + let scales = self.model().graph.get_input_scales(); + let input_types = self.model().graph.get_input_types()?; + info!("input scales: {:?}", scales); + + match &data.input_data { + DataSource::File(file_data) => { + self.load_file_data(file_data, &shapes, scales, input_types) + } + _ => Err("Cannot use non-file data source as input for this method.".into()), + } + } + + /// + #[cfg(not(target_arch = "wasm32"))] + pub async fn load_graph_input( + &mut self, + data: &GraphData, + ) -> Result>, Box> { + let shapes = self.model().graph.input_shapes()?; + let scales = self.model().graph.get_input_scales(); + let input_types = self.model().graph.get_input_types()?; + info!("input scales: {:?}", scales); + + self.process_data_source(&data.input_data, shapes, scales, input_types) + .await + } + + #[cfg(target_arch = "wasm32")] + /// Process the data source for the model + fn process_data_source( + &mut self, + data: &DataSource, + shapes: Vec>, + scales: Vec, + input_types: Vec, + ) -> Result>, Box> { + match &data { + DataSource::File(file_data) => { + self.load_file_data(file_data, &shapes, scales, input_types) + } + DataSource::OnChain(_) => { + Err("Cannot use on-chain data source as input for this method.".into()) + } + } + } + + #[cfg(not(target_arch = "wasm32"))] + /// Process the data source for the model + async fn process_data_source( + &mut self, + data: &DataSource, + shapes: Vec>, + scales: Vec, + input_types: Vec, + ) -> Result>, Box> { + match &data { + DataSource::OnChain(source) => { + let mut per_item_scale = vec![]; + for (i, shape) in shapes.iter().enumerate() 
{ + per_item_scale.extend(vec![scales[i]; shape.iter().product::()]); + } + self.load_on_chain_data(source.clone(), &shapes, per_item_scale) + .await + } + DataSource::File(file_data) => { + self.load_file_data(file_data, &shapes, scales, input_types) + } + DataSource::DB(pg) => { + let data = pg.fetch_and_format_as_file()?; + self.load_file_data(&data, &shapes, scales, input_types) + } + } + } + + /// Prepare on chain test data + #[cfg(not(target_arch = "wasm32"))] + pub async fn load_on_chain_data( + &mut self, + source: OnChainSource, + shapes: &Vec>, + scales: Vec, + ) -> Result>, Box> { + use crate::eth::{evm_quantize, read_on_chain_inputs, setup_eth_backend}; + let (_, client) = setup_eth_backend(Some(&source.rpc), None).await?; + let inputs = read_on_chain_inputs(client.clone(), client.address(), &source.calls).await?; + // quantize the supplied data using the provided scale + QuantizeData.sol + let quantized_evm_inputs = evm_quantize(client, scales, &inputs).await?; + // on-chain data has already been quantized at this point. Just need to reshape it and push into tensor vector + let mut inputs: Vec> = vec![]; + for (input, shape) in [quantized_evm_inputs].iter().zip(shapes) { + let mut t: Tensor = input.iter().cloned().collect(); + t.reshape(shape)?; + inputs.push(t); + } + + Ok(inputs) + } + + /// + pub fn load_file_data( + &mut self, + file_data: &FileSource, + shapes: &Vec>, + scales: Vec, + input_types: Vec, + ) -> Result>, Box> { + // quantize the supplied data using the provided scale. 
+ let mut data: Vec> = vec![]; + for (((d, shape), scale), input_type) in file_data + .iter() + .zip(shapes) + .zip(scales) + .zip(input_types.iter()) + { + let t: Vec = d + .par_iter() + .map(|x| { + let mut x = x.clone(); + x.as_type(input_type); + x.to_field(scale) + }) + .collect(); + + let mut t: Tensor = t.into_iter().into(); + t.reshape(shape)?; + + data.push(t); + } + Ok(data) + } + + /// + pub fn load_witness_file_data( + &mut self, + file_data: &[Vec], + shapes: &[Vec], + ) -> Result>, Box> { + // quantize the supplied data using the provided scale. + let mut data: Vec> = vec![]; + for (d, shape) in file_data.iter().zip(shapes) { + let mut t: Tensor = d.clone().into_iter().into(); + t.reshape(shape)?; + data.push(t); + } + Ok(data) + } + + fn reserved_blinding_rows() -> f64 { + (ASSUMED_BLINDING_FACTORS + RESERVED_BLINDING_ROWS_PAD) as f64 + } + + fn calc_safe_range(res: &GraphWitness) -> (i128, i128) { + ( + RANGE_MULTIPLIER * res.min_lookup_inputs, + RANGE_MULTIPLIER * res.max_lookup_inputs, + ) + } + + fn calc_num_cols(safe_range: (i128, i128), max_logrows: u32) -> usize { + let max_col_size = Table::::cal_col_size( + max_logrows as usize, + Self::reserved_blinding_rows() as usize, + ); + Table::::num_cols_required(safe_range, max_col_size) + } + + fn calc_min_logrows( + &mut self, + res: &GraphWitness, + max_logrows: Option, + ) -> Result<(), Box> { + // load the max logrows + let max_logrows = max_logrows.unwrap_or(MAX_PUBLIC_SRS); + let max_logrows = std::cmp::min(max_logrows, MAX_PUBLIC_SRS); + let mut max_logrows = std::cmp::max(max_logrows, MIN_LOGROWS); + + let reserved_blinding_rows = Self::reserved_blinding_rows(); + // check if has overflowed max lookup input + if res.max_lookup_inputs > MAX_LOOKUP_ABS / RANGE_MULTIPLIER + || res.min_lookup_inputs < -MAX_LOOKUP_ABS / RANGE_MULTIPLIER + { + let err_string = format!("max lookup input ({}) is too large", res.max_lookup_inputs); + return Err(err_string.into()); + } + + let safe_range = 
Self::calc_safe_range(res); + let mut min_logrows = MIN_LOGROWS; + // degrade the max logrows until the extended k is small enough + while min_logrows < max_logrows + && !self.extended_k_is_small_enough( + min_logrows, + Self::calc_num_cols(safe_range, min_logrows), + ) + { + min_logrows += 1; + } + + if !self + .extended_k_is_small_enough(min_logrows, Self::calc_num_cols(safe_range, min_logrows)) + { + let err_string = format!( + "extended k is too large to accomodate the quotient polynomial with logrows {}", + min_logrows + ); + return Err(err_string.into()); + } + + // degrade the max logrows until the extended k is small enough + while max_logrows > min_logrows + && !self.extended_k_is_small_enough( + max_logrows, + Self::calc_num_cols(safe_range, max_logrows), + ) + { + max_logrows -= 1; + } + + let min_bits = ((safe_range.1 - safe_range.0) as f64 + reserved_blinding_rows + 1.) + .log2() + .ceil() as usize; + + let min_rows_from_constraints = (self.settings().num_rows as f64 + reserved_blinding_rows) + .log2() + .ceil() as usize; + + let mut logrows = std::cmp::max(min_bits, min_rows_from_constraints); + + // if public input then public inputs col will have public inputs len + if self.settings().run_args.input_visibility.is_public() + || self.settings().run_args.output_visibility.is_public() + { + let mut max_instance_len = self + .model() + .instance_shapes()? 
+ .iter() + .fold(0, |acc, x| std::cmp::max(acc, x.iter().product::())) + as f64 + + reserved_blinding_rows; + // if there are modules then we need to add the max module size + if self.settings().uses_modules() { + max_instance_len += self + .settings() + .module_sizes + .num_instances() + .iter() + .sum::() as f64; + } + let instance_len_logrows = (max_instance_len).log2().ceil() as usize; + logrows = std::cmp::max(logrows, instance_len_logrows); + // this is for fixed const columns + } + + // ensure logrows is at least 4 + logrows = std::cmp::max(logrows, min_logrows as usize); + logrows = std::cmp::min(logrows, max_logrows as usize); + + let model = self.model().clone(); + let settings_mut = self.settings_mut(); + settings_mut.run_args.lookup_range = safe_range; + settings_mut.run_args.logrows = logrows as u32; + + *settings_mut = GraphCircuit::new(model, &settings_mut.run_args)? + .settings() + .clone(); + + // recalculate the total const size give nthe new logrows + let total_const_len = settings_mut.total_const_size; + let const_len_logrows = (total_const_len as f64).log2().ceil() as u32; + settings_mut.run_args.logrows = + std::cmp::max(settings_mut.run_args.logrows, const_len_logrows); + // recalculate the total number of constraints given the new logrows + let min_rows_from_constraints = (settings_mut.num_rows as f64 + reserved_blinding_rows) + .log2() + .ceil() as u32; + settings_mut.run_args.logrows = + std::cmp::max(settings_mut.run_args.logrows, min_rows_from_constraints); + + settings_mut.run_args.logrows = std::cmp::min(max_logrows, settings_mut.run_args.logrows); + + info!( + "setting lookup_range to: {:?}, setting logrows to: {}", + self.settings().run_args.lookup_range, + self.settings().run_args.logrows + ); + + Ok(()) + } + + fn extended_k_is_small_enough(&self, k: u32, num_lookup_cols: usize) -> bool { + let max_degree = self.settings().run_args.num_inner_cols + 2; + let max_lookup_degree = LOOKUP_DEG + num_lookup_cols - 1; // num_lookup_cols - 
1 is the degree of the lookup synthetic selector + + let max_degree = std::cmp::max(max_degree, max_lookup_degree); + // quotient_poly_degree * params.n - 1 is the degree of the quotient polynomial + let quotient_poly_degree = (max_degree - 1) as u64; + // n = 2^k + let n = 1u64 << k; + let mut extended_k = k; + while (1 << extended_k) < (n * quotient_poly_degree) { + extended_k += 1; + } + extended_k <= bn256::Fr::S + } + + /// Calibrate the circuit to the supplied data. + pub fn calibrate( + &mut self, + input: &[Tensor], + max_logrows: Option, + ) -> Result<(), Box> { + let res = self.forward(&mut input.to_vec(), None, None)?; + self.calc_min_logrows(&res, max_logrows) + } + + /// Runs the forward pass of the model / graph of computations and any associated hashing. + pub fn forward( + &self, + inputs: &mut [Tensor], + vk: Option<&VerifyingKey>, + srs: Option<&ParamsKZG>, + ) -> Result> { + let original_inputs = inputs.to_vec(); + + let visibility = VarVisibility::from_args(&self.settings().run_args)?; + let mut processed_inputs = None; + let mut processed_params = None; + let mut processed_outputs = None; + + if visibility.input.requires_processing() { + let module_outlets = visibility.input.overwrites_inputs(); + if !module_outlets.is_empty() { + let mut module_inputs = vec![]; + for outlet in &module_outlets { + module_inputs.push(inputs[*outlet].clone()); + } + let res = GraphModules::forward(&module_inputs, visibility.input.clone(), vk, srs)?; + processed_inputs = Some(res.clone()); + let module_results = res.get_result(visibility.input.clone()); + + for (i, outlet) in module_outlets.iter().enumerate() { + inputs[*outlet] = Tensor::from(module_results[i].clone().into_iter()); + } + } else { + processed_inputs = Some(GraphModules::forward(inputs, visibility.input, vk, srs)?); + } + } + + if visibility.params.requires_processing() { + let params = self.model().get_all_params(); + if !params.is_empty() { + let flattened_params = Tensor::new(Some(¶ms), 
&[params.len()])?.combine()?; + processed_params = Some(GraphModules::forward( + &[flattened_params], + visibility.params, + vk, + srs, + )?); + } + } + + let mut model_results = self.model().forward(inputs)?; + + if visibility.output.requires_processing() { + let module_outlets = visibility.output.overwrites_inputs(); + if !module_outlets.is_empty() { + let mut module_inputs = vec![]; + for outlet in &module_outlets { + module_inputs.push(model_results.outputs[*outlet].clone()); + } + let res = + GraphModules::forward(&module_inputs, visibility.output.clone(), vk, srs)?; + processed_outputs = Some(res.clone()); + let module_results = res.get_result(visibility.output.clone()); + + for (i, outlet) in module_outlets.iter().enumerate() { + model_results.outputs[*outlet] = + Tensor::from(module_results[i].clone().into_iter()); + } + } else { + processed_outputs = Some(GraphModules::forward( + &model_results.outputs, + visibility.output, + vk, + srs, + )?); + } + } + + let witness = GraphWitness { + inputs: original_inputs + .iter() + .map(|t| t.deref().to_vec()) + .collect_vec(), + outputs: model_results + .outputs + .iter() + .map(|t| t.deref().to_vec()) + .collect_vec(), + processed_inputs, + processed_params, + processed_outputs, + max_lookup_inputs: model_results.max_lookup_inputs, + min_lookup_inputs: model_results.min_lookup_inputs, + }; + + #[cfg(not(target_arch = "wasm32"))] + log::debug!( + "witness: \n {}", + &witness.as_json()?.to_colored_json_auto()? + ); + + Ok(witness) + } + + /// Create a new circuit from a set of input data and [RunArgs]. + #[cfg(not(target_arch = "wasm32"))] + pub fn from_run_args( + run_args: &RunArgs, + model_path: &std::path::Path, + ) -> Result> { + let model = Model::from_run_args(run_args, model_path)?; + Self::new(model, run_args) + } + + /// Create a new circuit from a set of input data and [GraphSettings]. 
+ #[cfg(not(target_arch = "wasm32"))] + pub fn from_settings( + params: &GraphSettings, + model_path: &std::path::Path, + check_mode: CheckMode, + ) -> Result> { + params.run_args.validate()?; + let model = Model::from_run_args(¶ms.run_args, model_path)?; + Self::new_from_settings(model, params.clone(), check_mode) + } + + /// + #[cfg(not(target_arch = "wasm32"))] + pub async fn populate_on_chain_test_data( + &mut self, + data: &mut GraphData, + test_on_chain_data: TestOnChainData, + ) -> Result<(), Box> { + // Set up local anvil instance for reading on-chain data + + if matches!( + test_on_chain_data.data_sources.input, + TestDataSource::OnChain + ) { + // if not public then fail + if self.settings().run_args.input_visibility.is_private() { + return Err("Cannot use on-chain data source as private data".into()); + } + + let input_data = match &data.input_data { + DataSource::File(input_data) => input_data, + _ => { + return Err("Cannot use non file source as input for on-chain test. + Manually populate on-chain data from file source instead" + .into()) + } + }; + // Get the flatten length of input_data + // if the input source is a field then set scale to 0 + + let datam: (Vec>, OnChainSource) = OnChainSource::test_from_file_data( + input_data, + self.model().graph.get_input_scales(), + self.model().graph.input_shapes()?, + test_on_chain_data.rpc.as_deref(), + ) + .await?; + data.input_data = datam.1.into(); + } + if matches!( + test_on_chain_data.data_sources.output, + TestDataSource::OnChain + ) { + // if not public then fail + if self.settings().run_args.output_visibility.is_private() { + return Err("Cannot use on-chain data source as private data".into()); + } + + let output_data = match &data.output_data { + Some(DataSource::File(output_data)) => output_data, + Some(DataSource::OnChain(_)) => { + return Err( + "Cannot use on-chain data source as output for on-chain test. 
+ Will manually populate on-chain data from file source instead" + .into(), + ) + } + _ => return Err("No output data found".into()), + }; + let datum: (Vec>, OnChainSource) = OnChainSource::test_from_file_data( + output_data, + self.model().graph.get_output_scales()?, + self.model().graph.output_shapes()?, + test_on_chain_data.rpc.as_deref(), + ) + .await?; + data.output_data = Some(datum.1.into()); + } + // Save the updated GraphData struct to the data_path + data.save(test_on_chain_data.data)?; + Ok(()) + } +} + +#[cfg(not(target_arch = "wasm32"))] +#[derive(Clone, Debug, Default, Serialize, Deserialize)] +struct CircuitSize { + num_instances: usize, + num_advice_columns: usize, + num_fixed: usize, + num_challenges: usize, + num_selectors: usize, +} + +#[cfg(not(target_arch = "wasm32"))] +impl CircuitSize { + pub fn from_cs(cs: &ConstraintSystem) -> Self { + CircuitSize { + num_instances: cs.num_instance_columns(), + num_advice_columns: cs.num_advice_columns(), + num_fixed: cs.num_fixed_columns(), + num_challenges: cs.num_challenges(), + num_selectors: cs.num_selectors(), + } + } + + /// Export the ezkl configuration as json + pub fn as_json(&self) -> Result> { + let serialized = match serde_json::to_string(&self) { + Ok(s) => s, + Err(e) => { + return Err(Box::new(e)); + } + }; + Ok(serialized) + } +} + +impl Circuit for GraphCircuit { + type Config = GraphConfig; + type FloorPlanner = ModulePlanner; + type Params = GraphSettings; + + fn without_witnesses(&self) -> Self { + self.clone() + } + + fn params(&self) -> Self::Params { + // safe to clone because the model is Arc'd + self.settings().clone() + } + + fn configure_with_params(cs: &mut ConstraintSystem, params: Self::Params) -> Self::Config { + let mut params = params.clone(); + params.set_num_blinding_factors(cs.blinding_factors()); + GLOBAL_SETTINGS.with(|settings| { + *settings.borrow_mut() = Some(params.clone()); + }); + let visibility = match VarVisibility::from_args(¶ms.run_args) { + Ok(v) => v, + 
Err(e) => { + log::error!("failed to create visibility: {:?}", e); + log::warn!("using default visibility"); + VarVisibility::default() + } + }; + + let mut module_configs = ModuleConfigs::from_visibility( + cs, + params.module_sizes.clone(), + params.run_args.logrows as usize, + ); + + let mut vars = ModelVars::new( + cs, + params.run_args.logrows as usize, + params.total_assignments, + params.run_args.num_inner_cols, + params.total_const_size, + params.module_requires_fixed(), + ); + + module_configs.configure_complex_modules(cs, visibility, params.module_sizes.clone()); + + vars.instantiate_instance( + cs, + params.model_instance_shapes, + params.run_args.input_scale, + module_configs.instance, + ); + + let base = Model::configure( + cs, + &vars, + params.run_args.lookup_range, + params.run_args.logrows as usize, + params.required_lookups, + params.check_mode, + ) + .unwrap(); + + let model_config = ModelConfig { base, vars }; + + debug!( + "degree: {}, log2_ceil of degrees: {:?}", + cs.degree(), + (cs.degree() as f32).log2().ceil() + ); + + #[cfg(not(target_arch = "wasm32"))] + info!( + "circuit size: \n {}", + CircuitSize::from_cs(cs) + .as_json() + .unwrap() + .to_colored_json_auto() + .unwrap() + ); + + GraphConfig { + model_config, + module_configs, + } + } + + fn configure(_: &mut ConstraintSystem) -> Self::Config { + unimplemented!("you should call configure_with_params instead") + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), PlonkError> { + trace!("Setting input in synthesize"); + let input_vis = &self.settings().run_args.input_visibility; + let output_vis = &self.settings().run_args.output_visibility; + let mut graph_modules = GraphModules::new(); + + let mut config = config.clone(); + + let mut inputs = self + .graph_witness + .get_input_tensor() + .iter_mut() + .map(|i| { + i.set_visibility(input_vis); + ValTensor::try_from(i.clone()).map_err(|e| { + log::error!("failed to convert input to 
valtensor: {:?}", e); + PlonkError::Synthesis + }) + }) + .collect::>, PlonkError>>()?; + + let outputs = self + .graph_witness + .get_output_tensor() + .iter_mut() + .map(|i| { + i.set_visibility(output_vis); + ValTensor::try_from(i.clone()).map_err(|e| { + log::error!("failed to convert output to valtensor: {:?}", e); + PlonkError::Synthesis + }) + }) + .collect::>, PlonkError>>()?; + + let mut instance_offset = 0; + trace!("running input module layout"); + + let input_visibility = &self.settings().run_args.input_visibility; + let outlets = input_visibility.overwrites_inputs(); + + if !outlets.is_empty() { + let mut input_outlets = vec![]; + for outlet in &outlets { + input_outlets.push(inputs[*outlet].clone()); + } + graph_modules.layout( + &mut layouter, + &mut config.module_configs, + &mut input_outlets, + input_visibility, + &mut instance_offset, + &self.module_settings.input, + )?; + // replace inputs with the outlets + for (i, outlet) in outlets.iter().enumerate() { + inputs[*outlet] = input_outlets[i].clone(); + } + } else { + graph_modules.layout( + &mut layouter, + &mut config.module_configs, + &mut inputs, + input_visibility, + &mut instance_offset, + &self.module_settings.input, + )?; + } + + // now we need to assign the flattened params to the model + let mut model = self.model().clone(); + let param_visibility = &self.settings().run_args.param_visibility; + trace!("running params module layout"); + if !self.model().get_all_params().is_empty() && param_visibility.requires_processing() { + // now we need to flatten the params + let consts = self.model().get_all_params(); + + let mut flattened_params = { + let mut t = Tensor::new(Some(&consts), &[consts.len()]) + .map_err(|_| { + log::error!("failed to flatten params"); + PlonkError::Synthesis + })? 
+ .combine() + .map_err(|_| { + log::error!("failed to combine params"); + PlonkError::Synthesis + })?; + t.set_visibility(param_visibility); + vec![t.try_into().map_err(|_| { + log::error!("failed to convert params to valtensor"); + PlonkError::Synthesis + })?] + }; + + // now do stuff to the model params + graph_modules.layout( + &mut layouter, + &mut config.module_configs, + &mut flattened_params, + param_visibility, + &mut instance_offset, + &self.module_settings.params, + )?; + + let shapes = self.model().const_shapes(); + trace!("replacing processed consts"); + let split_params = split_valtensor(&flattened_params[0], shapes).map_err(|_| { + log::error!("failed to split params"); + PlonkError::Synthesis + })?; + + // now the flattened_params have been assigned to and we-assign them to the model consts such that they are constrained to be equal + model.replace_consts(&split_params); + } + + // create a new module for the model (space 2) + layouter.assign_region(|| "_enter_module_2", |_| Ok(()))?; + trace!("laying out model"); + + let mut vars = config.model_config.vars.clone(); + vars.set_initial_instance_offset(instance_offset); + + let mut outputs = model + .layout( + config.model_config.clone(), + &mut layouter, + &self.settings().run_args, + &inputs, + &mut vars, + &outputs, + ) + .map_err(|e| { + log::error!("{}", e); + PlonkError::Synthesis + })?; + trace!("running output module layout"); + + let output_visibility = &self.settings().run_args.output_visibility; + let outlets = output_visibility.overwrites_inputs(); + + instance_offset += vars.get_instance_len(); + + if !outlets.is_empty() { + let mut output_outlets = vec![]; + for outlet in &outlets { + output_outlets.push(outputs[*outlet].clone()); + } + // this will re-enter module 0 + graph_modules.layout( + &mut layouter, + &mut config.module_configs, + &mut output_outlets, + &self.settings().run_args.output_visibility, + &mut instance_offset, + &self.module_settings.output, + )?; + + // replace 
outputs with the outlets + for (i, outlet) in outlets.iter().enumerate() { + outputs[*outlet] = output_outlets[i].clone(); + } + } else { + graph_modules.layout( + &mut layouter, + &mut config.module_configs, + &mut outputs, + &self.settings().run_args.output_visibility, + &mut instance_offset, + &self.module_settings.output, + )?; + } + + Ok(()) + } +} diff --git a/mnist_ezkl/src/graph/model.rs b/mnist_ezkl/src/graph/model.rs new file mode 100644 index 0000000..368d786 --- /dev/null +++ b/mnist_ezkl/src/graph/model.rs @@ -0,0 +1,1562 @@ +use super::extract_const_quantized_values; +use super::node::*; +use super::scale_to_multiplier; +use super::vars::*; +use super::GraphError; +use super::GraphSettings; +use crate::circuit::hybrid::HybridOp; +use crate::circuit::region::RegionCtx; +use crate::circuit::Input; +use crate::circuit::InputType; +use crate::circuit::Unknown; +use crate::fieldutils::felt_to_i128; +use crate::tensor::ValType; +use crate::{ + circuit::{lookup::LookupOp, BaseConfig as PolyConfig, CheckMode, Op}, + tensor::{Tensor, ValTensor}, + RunArgs, +}; +use halo2curves::bn256::Fr as Fp; + +#[cfg(not(target_arch = "wasm32"))] +use colored::Colorize; +use halo2_proofs::{ + circuit::{Layouter, Value}, + plonk::ConstraintSystem, +}; +use halo2curves::ff::Field; +use itertools::Itertools; +use serde::Deserialize; +use serde::Serialize; +#[cfg(not(target_arch = "wasm32"))] +use tract_onnx; +#[cfg(not(target_arch = "wasm32"))] +use tract_onnx::prelude::{ + Framework, Graph, InferenceFact, InferenceModelExt, SymbolValues, TypedFact, TypedOp, +}; +#[cfg(not(target_arch = "wasm32"))] +use tract_onnx::tract_hir::ops::scan::Scan; + +use log::error; +use log::{debug, info, trace}; +use std::collections::BTreeMap; +#[cfg(not(target_arch = "wasm32"))] +use std::collections::HashMap; +use std::collections::HashSet; +use std::error::Error; +use std::fs; +use std::io::Read; +use std::path::PathBuf; +#[cfg(not(target_arch = "wasm32"))] +use tabled::Table; +use 
unzip_n::unzip_n; + +unzip_n!(pub 3); + +/// The result of a forward pass. +#[derive(Clone, Debug)] +pub struct ForwardResult { + /// The outputs of the forward pass. + pub outputs: Vec>, + /// The maximum value of any input to a lookup operation. + pub max_lookup_inputs: i128, + /// The minimum value of any input to a lookup operation. + pub min_lookup_inputs: i128, +} + +/// A circuit configuration for the entirety of a model loaded from an Onnx file. +#[derive(Clone, Debug)] +pub struct ModelConfig { + /// The base configuration for the circuit + pub base: PolyConfig, + /// A wrapper for holding all columns that will be assigned to by the model + pub vars: ModelVars, +} + +/// Representation of execution graph +pub type NodeGraph = BTreeMap; + +/// A struct for loading from an Onnx file and converting a computational graph to a circuit. +#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] +pub struct Model { + /// input indices + pub graph: ParsedNodes, + /// Defines which inputs to the model are public and private (params, inputs, outputs) using [VarVisibility]. + pub visibility: VarVisibility, +} + +/// +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +pub enum OutputMapping { + /// + Single { + /// + outlet: usize, + /// + is_state: bool, + }, + /// + Stacked { + /// + outlet: usize, + /// + axis: usize, + /// + is_state: bool, + }, +} + +impl OutputMapping { + /// + pub fn is_state(&self) -> bool { + match self { + OutputMapping::Single { is_state, .. } => *is_state, + OutputMapping::Stacked { is_state, .. } => *is_state, + } + } + + /// + pub fn outlet(&self) -> usize { + match self { + OutputMapping::Single { outlet, .. } => *outlet, + OutputMapping::Stacked { outlet, .. 
} => *outlet, + } + } +} + +/// +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +pub enum InputMapping { + /// + Full, + /// + State, + /// + Stacked { + /// + axis: usize, + /// + chunk: usize, + }, +} + +fn number_of_iterations(mappings: &[InputMapping], dims: Vec<&[usize]>) -> usize { + let mut number_of_iterations = + dims.iter() + .zip(mappings) + .filter_map(|(dims, mapping)| match mapping { + InputMapping::Stacked { axis, chunk } => Some( + // number of iterations given the dim size along the axis + // and the chunk size + (dims[*axis] + chunk - 1) / chunk, + ), + _ => None, + }); + // assert all collected number of iterations are equal + assert!(number_of_iterations.clone().all_equal()); + + number_of_iterations.next().unwrap_or(1) +} + +fn input_state_idx(input_mappings: &[InputMapping]) -> Vec { + input_mappings + .iter() + .enumerate() + .filter(|(_, r)| matches!(r, InputMapping::State)) + .map(|(index, _)| index) + .collect::>() +} + +fn output_state_idx(output_mappings: &[Vec]) -> Vec { + output_mappings + .iter() + .flatten() + .filter_map(|x| if x.is_state() { Some(x.outlet()) } else { None }) + .collect::>() +} + +/// Enables model as subnode of other models +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +pub enum NodeType { + /// A node in the model + Node(Node), + /// A submodel + SubGraph { + /// The subgraph + model: Model, + /// The subgraph's inputs + inputs: Vec, + /// the subgraph's idx within the parent graph + idx: usize, + /// output mappings + output_mappings: Vec>, + /// input mappings + input_mappings: Vec, + /// + out_dims: Vec>, + /// + out_scales: Vec, + }, +} + +impl NodeType { + /// + pub fn is_lookup(&self) -> bool { + match self { + NodeType::Node(n) => n.opkind.is_lookup(), + NodeType::SubGraph { .. } => false, + } + } + /// + pub fn num_uses(&self) -> usize { + match self { + NodeType::Node(n) => n.num_uses, + NodeType::SubGraph { .. } => 0, + } + } + + /// Returns the indices of the node's inputs. 
+ pub fn inputs(&self) -> Vec { + match self { + NodeType::Node(n) => n.inputs.clone(), + NodeType::SubGraph { inputs, .. } => inputs.clone(), + } + } + + /// Returns the dimensions of the node's output. + pub fn out_dims(&self) -> Vec> { + match self { + NodeType::Node(n) => vec![n.out_dims.clone()], + NodeType::SubGraph { out_dims, .. } => out_dims.clone(), + } + } + /// Returns the lookups required by a graph + pub fn required_lookups(&self) -> Vec { + match self { + NodeType::Node(n) => n.opkind.required_lookups(), + NodeType::SubGraph { model, .. } => model.required_lookups(), + } + } + /// Returns the scales of the node's output. + pub fn out_scales(&self) -> Vec { + match self { + NodeType::Node(n) => vec![n.out_scale], + NodeType::SubGraph { out_scales, .. } => out_scales.clone(), + } + } + + /// Returns a string representation of the operation. + pub fn as_str(&self) -> String { + match self { + NodeType::Node(n) => n.opkind.as_string(), + NodeType::SubGraph { .. } => "SUBGRAPH".into(), + } + } + + /// Returns true if the operation is a rebase + pub fn is_rebase(&self) -> bool { + match self { + NodeType::Node(n) => matches!(n.opkind, SupportedOp::RebaseScale { .. }), + NodeType::SubGraph { .. } => false, + } + } + + /// Returns true if the operation is an input. + pub fn is_input(&self) -> bool { + match self { + NodeType::Node(n) => n.opkind.is_input(), + NodeType::SubGraph { .. } => false, + } + } + /// Returns true if the operation is a const. + pub fn is_constant(&self) -> bool { + match self { + NodeType::Node(n) => n.opkind.is_constant(), + NodeType::SubGraph { .. } => false, + } + } + + /// Returns the node's unique identifier. + pub fn idx(&self) -> usize { + match self { + NodeType::Node(n) => n.idx, + NodeType::SubGraph { idx, .. } => *idx, + } + } + + /// decrement const num times used + pub fn decrement_use(&mut self) { + match self { + NodeType::Node(n) => n.num_uses -= 1, + NodeType::SubGraph { .. 
} => log::warn!("Cannot decrement const of subgraph"), + } + } + + /// bunp scale of node + pub fn bump_scale(&mut self, scale: crate::Scale) { + match self { + NodeType::Node(n) => n.out_scale = scale, + NodeType::SubGraph { .. } => log::warn!("Cannot bump scale of subgraph"), + } + } + + /// Replace the operation kind of the node. + pub fn replace_opkind(&mut self, opkind: SupportedOp) { + match self { + NodeType::Node(n) => n.opkind = opkind, + NodeType::SubGraph { .. } => log::warn!("Cannot replace opkind of subgraph"), + } + } + + /// Returns the operation kind of the node (if any). + pub fn opkind(&self) -> SupportedOp { + match self { + NodeType::Node(n) => n.opkind.clone(), + NodeType::SubGraph { .. } => SupportedOp::Unknown(Unknown), + } + } +} + +#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] +/// A set of EZKL nodes that represent a computational graph. +pub struct ParsedNodes { + /// The nodes in the graph. + pub nodes: BTreeMap, + inputs: Vec, + outputs: Vec, +} + +impl ParsedNodes { + /// Returns the number of the computational graph's inputs + pub fn num_inputs(&self) -> usize { + let input_nodes = self.inputs.iter(); + input_nodes.len() + } + + /// Input types + pub fn get_input_types(&self) -> Result, GraphError> { + self.inputs + .iter() + .map(|o| { + match self + .nodes + .get(o) + .ok_or(GraphError::MissingNode(*o))? + .opkind() + { + SupportedOp::Input(Input { datum_type, .. 
}) => Ok(datum_type.clone()), + _ => Err(GraphError::InvalidInputTypes), + } + }) + .collect::, _>>() + } + + /// Returns shapes of the computational graph's inputs + pub fn input_shapes(&self) -> Result>, Box> { + let mut inputs = vec![]; + + for input in self.inputs.iter() { + let node = self + .nodes + .get(input) + .ok_or(GraphError::MissingNode(*input))?; + let input_dims = node.out_dims(); + let input_dim = input_dims.get(0).ok_or(GraphError::MissingNode(*input))?; + inputs.push(input_dim.clone()); + } + + Ok(inputs) + } + + /// Returns the number of the computational graph's outputs + pub fn num_outputs(&self) -> usize { + let output_nodes = self.outputs.iter(); + output_nodes.len() + } + + /// Returns shapes of the computational graph's outputs + pub fn output_shapes(&self) -> Result>, GraphError> { + let mut outputs = vec![]; + + for output in self.outputs.iter() { + let (idx, outlet) = output; + let node = self.nodes.get(idx).ok_or(GraphError::MissingNode(*idx))?; + let out_dims = node.out_dims(); + let out_dim = out_dims + .get(*outlet) + .ok_or(GraphError::MissingNode(*outlet))?; + outputs.push(out_dim.clone()); + } + + Ok(outputs) + } + + /// Returns the fixed point scale of the computational graph's inputs + pub fn get_input_scales(&self) -> Vec { + let input_nodes = self.inputs.iter(); + input_nodes + .map(|idx| { + self.nodes + .get(idx) + .ok_or(GraphError::MissingNode(*idx)) + .map(|n| n.out_scales()) + .unwrap_or_default() + }) + .flatten() + .collect() + } + + /// Returns the fixed point scale of the computational graph's outputs + pub fn get_output_scales(&self) -> Result, GraphError> { + let output_nodes = self.outputs.iter(); + output_nodes + .map(|(idx, outlet)| { + Ok(self + .nodes + .get(idx) + .ok_or(GraphError::MissingNode(*idx))? 
+ .out_scales()[*outlet]) + }) + .collect::, GraphError>>() + } +} + +impl Model { + fn required_lookups(&self) -> Vec { + self.graph + .nodes + .values() + .flat_map(|n| n.required_lookups()) + .collect_vec() + } + + /// Creates a `Model` from a specified path to an Onnx file. + /// # Arguments + /// * `reader` - A reader for an Onnx file. + /// * `run_args` - [RunArgs] + #[cfg(not(target_arch = "wasm32"))] + pub fn new(reader: &mut dyn std::io::Read, run_args: &RunArgs) -> Result> { + let visibility = VarVisibility::from_args(run_args)?; + + let graph = Self::load_onnx_model(reader, run_args, &visibility)?; + + let om = Model { graph, visibility }; + + debug!("\n {}", om.table_nodes()); + + Ok(om) + } + + /// + pub fn save(&self, path: PathBuf) -> Result<(), Box> { + let f = std::fs::File::create(path)?; + let writer = std::io::BufWriter::new(f); + bincode::serialize_into(writer, &self)?; + Ok(()) + } + + /// + pub fn load(path: PathBuf) -> Result> { + // read bytes from file + let mut f = std::fs::File::open(&path)?; + let metadata = fs::metadata(&path)?; + let mut buffer = vec![0; metadata.len() as usize]; + f.read_exact(&mut buffer)?; + let result = bincode::deserialize(&buffer)?; + Ok(result) + } + + /// Generate model parameters for the circuit + pub fn gen_params( + &self, + run_args: &RunArgs, + check_mode: CheckMode, + ) -> Result> { + let instance_shapes = self.instance_shapes()?; + #[cfg(not(target_arch = "wasm32"))] + info!( + "{} {} {}", + "model has".blue(), + instance_shapes.len().to_string().blue(), + "instances".blue() + ); + // this is the total number of variables we will need to allocate + // for the circuit + let (num_rows, linear_coord, total_const_size) = + self.dummy_layout(run_args, &self.graph.input_shapes()?)?; + + // extract the requisite lookup ops from the model + let mut lookup_ops: Vec = self.required_lookups(); + + // if we're using percentage tolerance, we need to add the necessary range check ops for it. 
+ + if run_args.tolerance.val > 0.0 { + for scale in self.graph.get_output_scales()? { + let mut tolerance = run_args.tolerance; + tolerance.scale = scale_to_multiplier(scale).into(); + let opkind: Box> = Box::new(HybridOp::RangeCheck(tolerance)); + lookup_ops.extend(opkind.required_lookups()); + } + } + + let set: HashSet<_> = lookup_ops.drain(..).collect(); // dedup + lookup_ops.extend(set.into_iter().sorted()); + + Ok(GraphSettings { + run_args: run_args.clone(), + model_instance_shapes: instance_shapes, + module_sizes: crate::graph::modules::ModuleSizes::default(), + num_rows, + total_assignments: linear_coord, + required_lookups: lookup_ops, + model_output_scales: self.graph.get_output_scales()?, + model_input_scales: self.graph.get_input_scales(), + total_const_size, + check_mode, + version: env!("CARGO_PKG_VERSION").to_string(), + num_blinding_factors: None, + }) + } + + /// Runs a forward pass on sample data ! + /// # Arguments + /// * `reader` - A reader for an Onnx file. + /// * `model_inputs` - A vector of [Tensor]s to use as inputs to the model. 
+ /// * `run_args` - [RunArgs] + pub fn forward(&self, model_inputs: &[Tensor]) -> Result> { + let mut results: BTreeMap<&usize, Vec>> = BTreeMap::new(); + let mut max_lookup_inputs = 0; + let mut min_lookup_inputs = 0; + + let input_shapes = self.graph.input_shapes()?; + + for (i, input_idx) in self.graph.inputs.iter().enumerate() { + let mut input = model_inputs[i].clone(); + input.reshape(&input_shapes[i])?; + results.insert(input_idx, vec![input]); + } + + for (idx, n) in self.graph.nodes.iter() { + let mut inputs = vec![]; + if n.is_input() { + let t = results.get(idx).ok_or(GraphError::MissingResults)?[0].clone(); + inputs.push(t); + } else { + for (idx, outlet) in n.inputs().iter() { + match results.get(&idx) { + Some(value) => inputs.push(value[*outlet].clone()), + None => return Err(Box::new(GraphError::MissingNode(*idx))), + } + } + }; + + debug!("executing {}: {}", idx, n.as_str()); + debug!("dims: {:?}", n.out_dims()); + debug!( + "input_dims: {:?}", + inputs.iter().map(|x| x.dims()).collect::>() + ); + + if n.is_lookup() { + let (mut min, mut max) = (0, 0); + for i in &inputs { + max = max.max( + i.iter() + .map(|x| felt_to_i128(*x)) + .max() + .ok_or("missing max")?, + ); + min = min.min( + i.iter() + .map(|x| felt_to_i128(*x)) + .min() + .ok_or("missing min")?, + ); + } + max_lookup_inputs = max_lookup_inputs.max(max); + min_lookup_inputs = min_lookup_inputs.min(min); + debug!("max lookup inputs: {}", max); + debug!("min lookup inputs: {}", min); + } + + match n { + NodeType::Node(n) => { + // execute the op + let start = instant::Instant::now(); + let res = Op::::f(&n.opkind, &inputs)?; + let elapsed = start.elapsed(); + trace!("op took: {:?}", elapsed); + // see if any of the intermediate lookup calcs are the max + if !res.intermediate_lookups.is_empty() { + let (mut min, mut max) = (0, 0); + for i in &res.intermediate_lookups { + max = max.max(i.clone().into_iter().max().ok_or("missing max")?); + min = 
min.min(i.clone().into_iter().min().ok_or("missing min")?); + } + max_lookup_inputs = max_lookup_inputs.max(max); + min_lookup_inputs = min_lookup_inputs.min(min); + debug!("intermediate max lookup inputs: {}", max); + debug!("intermediate min lookup inputs: {}", min); + } + debug!( + "------------ output node int {}: {} \n ------------ float: {} \n ------------ max: {} \n ------------ min: {}", + idx, + res.output.map(crate::fieldutils::felt_to_i32).show(), + res.output + .map(|x| crate::fieldutils::felt_to_f64(x) + / scale_to_multiplier(n.out_scale)) + .show(), + res.output.clone().into_iter().map(crate::fieldutils::felt_to_i128).max().unwrap_or(0), + res.output.clone().into_iter().map(crate::fieldutils::felt_to_i128).min().unwrap_or(0), + ); + results.insert(idx, vec![res.output]); + } + NodeType::SubGraph { + model, + output_mappings, + input_mappings, + inputs: input_tuple, + .. + } => { + let orig_inputs = inputs.clone(); + let input_mappings = input_mappings.clone(); + + let input_dims = inputs.iter().map(|inp| inp.dims()); + let num_iter = number_of_iterations(&input_mappings, input_dims.collect()); + + debug!( + "{} iteration(s) in a subgraph with inputs {:?} and sources {:?}", + num_iter, input_tuple, model.graph.inputs + ); + + debug!("input_mappings: {:?}", input_mappings); + + let mut full_results: Vec> = vec![]; + + for i in 0..num_iter { + // replace the Stacked input with the current chunk iter + for ((mapping, inp), og_input) in + input_mappings.iter().zip(&mut inputs).zip(&orig_inputs) + { + if let InputMapping::Stacked { axis, chunk } = mapping { + let start = i * chunk; + let end = (i + 1) * chunk; + let t = crate::tensor::ops::slice(og_input, axis, &start, &end)?; + *inp = t; + } + } + + let res = model.forward(&inputs)?; + // recursively get the max lookup inputs for subgraphs + max_lookup_inputs = max_lookup_inputs.max(res.max_lookup_inputs); + min_lookup_inputs = min_lookup_inputs.min(res.min_lookup_inputs); + + let mut outlets = 
BTreeMap::new(); + for (mappings, outlet_res) in output_mappings.iter().zip(res.outputs) { + for mapping in mappings { + match mapping { + OutputMapping::Single { outlet, .. } => { + outlets.insert(outlet, outlet_res.clone()); + } + OutputMapping::Stacked { outlet, axis, .. } => { + if !full_results.is_empty() { + let stacked_res = crate::tensor::ops::concat( + &[&full_results[*outlet], &outlet_res], + *axis, + )?; + + outlets.insert(outlet, stacked_res); + } else { + outlets.insert(outlet, outlet_res.clone()); + } + } + } + } + } + + full_results = outlets.into_values().collect_vec(); + + let output_states = output_state_idx(output_mappings); + let input_states = input_state_idx(&input_mappings); + + assert_eq!(input_states.len(), output_states.len()); + + for (input_idx, output_idx) in input_states.iter().zip(output_states) { + inputs[*input_idx] = full_results[output_idx].clone(); + } + } + + trace!( + "------------ output subgraph node {}: {:?}", + idx, + full_results + .iter() + .map(|x| + // convert to tensor i32 + x.map(crate::fieldutils::felt_to_i32).show()) + .collect_vec() + ); + + results.insert(idx, full_results); + } + } + } + + let output_nodes = self.graph.outputs.iter(); + debug!( + "model outputs are nodes: {:?}", + output_nodes.clone().collect_vec() + ); + let outputs = output_nodes + .map(|(idx, outlet)| { + Ok(results.get(&idx).ok_or(GraphError::MissingResults)?[*outlet].clone()) + }) + .collect::, GraphError>>()?; + + let res = ForwardResult { + outputs, + max_lookup_inputs, + min_lookup_inputs, + }; + + Ok(res) + } + + /// Loads an Onnx model from a specified path. + /// # Arguments + /// * `reader` - A reader for an Onnx file. + /// * `scale` - The scale to use for quantization. + /// * `public_params` - Whether to make the params public. 
+ #[cfg(not(target_arch = "wasm32"))] + fn load_onnx_model( + reader: &mut dyn std::io::Read, + run_args: &RunArgs, + visibility: &VarVisibility, + ) -> Result> { + use tract_onnx::tract_hir::internal::GenericFactoid; + + let start_time = instant::Instant::now(); + + let mut model = tract_onnx::onnx().model_for_read(reader).map_err(|e| { + error!("Error loading model: {}", e); + GraphError::ModelLoad + })?; + + let variables: std::collections::HashMap = + std::collections::HashMap::from_iter(run_args.variables.clone()); + + for (i, id) in model.clone().inputs.iter().enumerate() { + let input = model.node_mut(id.node); + let mut fact: InferenceFact = input.outputs[0].fact.clone(); + + for (i, x) in fact.clone().shape.dims().enumerate() { + if matches!(x, GenericFactoid::Any) { + let batch_size = match variables.get("batch_size") { + Some(x) => x, + None => return Err("Unknown dimension batch_size in model inputs, set batch_size in variables".into()), + }; + fact.shape + .set_dim(i, tract_onnx::prelude::TDim::Val(*batch_size as i64)); + } + } + + model.set_input_fact(i, fact)?; + } + + for (i, _) in model.clone().outputs.iter().enumerate() { + model.set_output_fact(i, InferenceFact::default())?; + } + // Note: do not optimize the model, as the layout will depend on underlying hardware + let mut model = model.into_typed()?.into_decluttered()?; + let mut symbol_values = SymbolValues::default(); + for (symbol, value) in run_args.variables.iter() { + let symbol = model.symbol_table.sym(symbol); + symbol_values = symbol_values.with(&symbol, *value as i64); + info!("set {} to {}", symbol, value); + } + model = model.concretize_dims(&symbol_values)?; + + let scales = VarScales::from_args(run_args)?; + let nodes = Self::nodes_from_graph( + &model, + run_args, + &scales, + visibility, + &symbol_values, + None, + None, + )?; + + debug!("\n {}", model); + + let parsed_nodes = ParsedNodes { + nodes, + inputs: model.inputs.iter().map(|o| o.node).collect(), + outputs: 
model.outputs.iter().map(|o| (o.node, o.slot)).collect(), + }; + + let duration = start_time.elapsed(); + trace!("model loading took: {:?}", duration); + + Ok(parsed_nodes) + } + + /// Formats nodes (including subgraphs) into tables ! + #[cfg(not(target_arch = "wasm32"))] + pub fn table_nodes(&self) -> String { + let mut node_accumulator = vec![]; + let mut string = String::new(); + for (idx, node) in &self.graph.nodes { + match node { + NodeType::Node(n) => { + node_accumulator.push(n); + } + NodeType::SubGraph { model, inputs, .. } => { + let mut table = Table::new(node_accumulator.iter()); + table.with(tabled::settings::Style::modern()); + table.with(tabled::settings::Shadow::new(1)); + table.with( + tabled::settings::style::BorderColor::default() + .top(tabled::settings::Color::BG_YELLOW), + ); + string = format!("{} \n\n MAIN GRAPH \n\n{}", string, table); + node_accumulator = vec![]; + string = format!( + "{}\n\n SUBGRAPH AT IDX {} WITH INPUTS {:?}\n{}", + string, + idx, + inputs, + model.table_nodes(), + ); + } + } + } + + let mut table = Table::new(node_accumulator.iter()); + table.with(tabled::settings::Style::modern()); + format!("{} \n{}", string, table) + } + + /// Creates ezkl nodes from a tract graph + /// # Arguments + /// * `graph` - A tract graph. + /// * `run_args` - [RunArgs] + /// * `visibility` - Which inputs to the model are public and private (params, inputs, outputs) using [VarVisibility]. + /// * `input_scales` - The scales of the model's inputs. 
+ + #[cfg(not(target_arch = "wasm32"))] + pub fn nodes_from_graph( + graph: &Graph>, + run_args: &RunArgs, + scales: &VarScales, + visibility: &VarVisibility, + symbol_values: &SymbolValues, + override_input_scales: Option>, + override_output_scales: Option>, + ) -> Result, Box> { + use crate::graph::node_output_shapes; + + let mut nodes = BTreeMap::::new(); + let mut input_idx = 0; + for (i, n) in graph.nodes.iter().enumerate() { + // Extract the slope layer hyperparams + match n.op().downcast_ref::() { + Some(b) => { + let model = b.body.clone(); + let input_scales = n + .inputs + .iter() + .map(|i| { + Ok(nodes + .get(&i.node) + .ok_or(GraphError::MissingNode(i.node))? + .out_scales()[0]) + }) + .collect::, GraphError>>()?; + + let mut input_mappings = vec![]; + for mapping in &b.input_mapping { + match mapping { + tract_onnx::tract_hir::ops::scan::InputMapping::Scan(info) => { + input_mappings.push(InputMapping::Stacked { + axis: info.axis, + chunk: info.chunk as usize, + }); + } + tract_onnx::tract_hir::ops::scan::InputMapping::State => { + input_mappings.push(InputMapping::State); + } + tract_onnx::tract_hir::ops::scan::InputMapping::Full => { + input_mappings.push(InputMapping::Full); + } + } + } + + let input_state_idx = input_state_idx(&input_mappings); + + let mut output_mappings = vec![]; + for mapping in b.output_mapping.iter() { + let mut mappings = vec![]; + if let Some(outlet) = mapping.last_value_slot { + mappings.push(OutputMapping::Single { + outlet, + is_state: mapping.state, + }); + } + if let Some(last) = mapping.scan { + mappings.push(OutputMapping::Stacked { + outlet: last.0, + axis: last.1.axis, + is_state: false, + }); + } + output_mappings.push(mappings); + } + + let output_state_idx = output_state_idx(&output_mappings); + + let mut output_scale_override = HashMap::new(); + // if input_state_idx and output_state_idx have mismatched scales we need to rebase the scale of the output node + for (input_idx, output_idx) in 
input_state_idx.iter().zip(output_state_idx) { + let input_scale = input_scales[*input_idx]; + // output mappings is a vec of vec. we need to find the outer index of the output node we want to rebase. + let mut traversed_len = 0; + for (outer_idx, mappings) in output_mappings.iter().enumerate() { + let mapping_len = mappings.len(); + if traversed_len + mapping_len > output_idx { + let output_node_idx = b.body.outputs[outer_idx].node; + output_scale_override.insert(output_node_idx, input_scale); + } + traversed_len += mapping_len; + } + } + + let subgraph_nodes = Self::nodes_from_graph( + &model, + run_args, + scales, + visibility, + symbol_values, + Some(input_scales.clone()), + Some(output_scale_override), + )?; + + let subgraph = ParsedNodes { + nodes: subgraph_nodes, + inputs: model.inputs.iter().map(|o| o.node).collect(), + outputs: model.outputs.iter().map(|o| (o.node, o.slot)).collect(), + }; + + let om = Model { + graph: subgraph, + visibility: visibility.clone(), + }; + + let out_dims = node_output_shapes(n)? 
+ .iter() + .map(|shape| Ok(shape.as_ref().ok_or("missing shape dims")?.clone())) + .collect::, Box>>()?; + + let mut output_scales = BTreeMap::new(); + + for (i, _mapping) in b.output_mapping.iter().enumerate() { + for mapping in b.output_mapping.iter() { + if let Some(outlet) = mapping.last_value_slot { + output_scales.insert(outlet, om.graph.get_output_scales()?[i]); + } + if let Some(last) = mapping.scan { + output_scales.insert(last.0, om.graph.get_output_scales()?[i]); + } + } + } + + let out_scales = output_scales.into_values().collect_vec(); + + nodes.insert( + i, + NodeType::SubGraph { + model: om, + inputs: n.inputs.iter().map(|i| (i.node, i.slot)).collect_vec(), + idx: i, + output_mappings, + input_mappings, + out_dims, + out_scales, + }, + ); + } + None => { + let mut n = Node::new( + n.clone(), + &mut nodes, + scales, + &run_args.param_visibility, + i, + symbol_values, + )?; + if let Some(ref scales) = override_input_scales { + if let Some(inp) = n.opkind.get_input() { + let scale = scales[input_idx]; + n.opkind = SupportedOp::Input(Input { + scale, + datum_type: inp.datum_type, + }); + input_idx += 1; + n.out_scale = scale; + } + } + if let Some(ref scales) = override_output_scales { + if scales.contains_key(&i) { + let scale_diff = n.out_scale - scales[&i]; + n.opkind = if scale_diff > 0 { + RebaseScale::rebase(n.opkind, scales[&i], n.out_scale, 1) + } else { + RebaseScale::rebase_up(n.opkind, scales[&i], n.out_scale) + }; + n.out_scale = scales[&i]; + } + } + nodes.insert(i, NodeType::Node(n)); + } + } + } + Self::remove_unused_nodes(&mut nodes); + + Ok(nodes) + } + + #[cfg(not(target_arch = "wasm32"))] + /// Removes all nodes that are consts with 0 uses + fn remove_unused_nodes(nodes: &mut BTreeMap) { + // remove all nodes that are consts with 0 uses now + nodes.retain(|_, n| match n { + NodeType::Node(n) => match &mut n.opkind { + SupportedOp::Constant(c) => { + c.empty_raw_value(); + n.num_uses > 0 + } + _ => n.num_uses > 0, + }, + 
NodeType::SubGraph { model, .. } => { + Self::remove_unused_nodes(&mut model.graph.nodes); + true + } + }); + } + + /// Creates a `Model` from parsed run_args + /// # Arguments + /// * `params` - A [GraphSettings] struct holding parsed CLI arguments. + #[cfg(not(target_arch = "wasm32"))] + pub fn from_run_args( + run_args: &RunArgs, + model: &std::path::Path, + ) -> Result> { + Model::new( + &mut std::fs::File::open(model) + .map_err(|_| format!("failed to load model at {}", model.display()))?, + run_args, + ) + } + + /// Configures a model for the circuit + /// # Arguments + /// * `meta` - The constraint system. + /// * `vars` - The variables for the circuit. + /// * `run_args` - [RunArgs] + /// * `required_lookups` - The required lookup operations for the circuit. + pub fn configure( + meta: &mut ConstraintSystem, + vars: &ModelVars, + lookup_range: (i128, i128), + logrows: usize, + required_lookups: Vec, + check_mode: CheckMode, + ) -> Result, Box> { + info!("configuring model"); + + let mut base_gate = PolyConfig::configure( + meta, + vars.advices[0..2].try_into()?, + &vars.advices[2], + check_mode, + ); + // set scale for HybridOp::RangeCheck and call self.conf_lookup on that op for percentage tolerance case + let input = &vars.advices[0]; + let output = &vars.advices[1]; + let index = &vars.advices[2]; + for op in required_lookups { + base_gate.configure_lookup(meta, input, output, index, lookup_range, logrows, &op)?; + } + + Ok(base_gate) + } + + /// Assigns values to the regions created when calling `configure`. + /// # Arguments + /// * `config` - [ModelConfig] holding all node configs. + /// * `layouter` - Halo2 Layouter. + /// * `inputs` - The values to feed into the circuit. + /// * `vars` - The variables for the circuit. 
+ pub fn layout( + &self, + mut config: ModelConfig, + layouter: &mut impl Layouter, + run_args: &RunArgs, + inputs: &[ValTensor], + vars: &mut ModelVars, + witnessed_outputs: &[ValTensor], + ) -> Result>, Box> { + info!("model layout..."); + + let start_time = instant::Instant::now(); + + let mut results = BTreeMap::>>::new(); + + let input_shapes = self.graph.input_shapes()?; + for (i, input_idx) in self.graph.inputs.iter().enumerate() { + if self.visibility.input.is_public() { + let instance = vars.instance.as_ref().ok_or("no instance")?.clone(); + results.insert(*input_idx, vec![instance]); + vars.increment_instance_idx(); + } else { + let mut input = inputs[i].clone(); + input.reshape(&input_shapes[i])?; + results.insert(*input_idx, vec![input]); + } + } + + let instance_idx = vars.get_instance_idx(); + + config.base.layout_tables(layouter)?; + + let mut num_rows = 0; + let mut linear_coord = 0; + let mut total_const_size = 0; + + let outputs = layouter.assign_region( + || "model", + |region| { + let mut thread_safe_region = RegionCtx::new(region, 0, run_args.num_inner_cols); + // we need to do this as this loop is called multiple times + vars.set_instance_idx(instance_idx); + + let outputs = self + .layout_nodes(&mut config, &mut thread_safe_region, &mut results) + .map_err(|e| { + error!("{}", e); + halo2_proofs::plonk::Error::Synthesis + })?; + + if run_args.output_visibility.is_public() || run_args.output_visibility.is_fixed() { + let output_scales = self.graph.get_output_scales().map_err(|e| { + error!("{}", e); + halo2_proofs::plonk::Error::Synthesis + })?; + let res = outputs + .iter() + .enumerate() + .map(|(i, output)| { + let mut tolerance = run_args.tolerance; + tolerance.scale = scale_to_multiplier(output_scales[i]).into(); + + let comparators = if run_args.output_visibility == Visibility::Public { + let res = vars.instance.as_ref().ok_or("no instance")?.clone(); + vars.increment_instance_idx(); + res + } else { + // if witnessed_outputs is of len 
less than i error + if witnessed_outputs.len() <= i { + return Err("you provided insufficient witness values to generate a fixed output".into()); + } + witnessed_outputs[i].clone() + }; + + config.base.layout( + &mut thread_safe_region, + &[output.clone(), comparators], + Box::new(HybridOp::RangeCheck(tolerance)), + ) + }) + .collect::,_>>(); + res.map_err(|e| { + error!("{}", e); + halo2_proofs::plonk::Error::Synthesis + })?; + } else if !run_args.output_visibility.is_private() { + for output in &outputs { + thread_safe_region.increment_total_constants(output.num_constants()); + } + } + num_rows = thread_safe_region.row(); + linear_coord = thread_safe_region.linear_coord(); + total_const_size = thread_safe_region.total_constants(); + + Ok(outputs) + }, + )?; + + // Then number of columns in the circuits + #[cfg(not(target_arch = "wasm32"))] + info!( + "{} {} {} (coord={}, constants={})", + "model uses".blue(), + num_rows.to_string().blue(), + "rows".blue(), + linear_coord.to_string().yellow(), + total_const_size.to_string().red() + ); + + let duration = start_time.elapsed(); + trace!("model layout took: {:?}", duration); + + Ok(outputs) + } + + fn layout_nodes( + &self, + config: &mut ModelConfig, + region: &mut RegionCtx, + results: &mut BTreeMap>>, + ) -> Result>, Box> { + // index over results to get original inputs + let orig_inputs: BTreeMap = results + .clone() + .into_iter() + .filter(|(idx, _)| self.graph.inputs.contains(idx)) + .collect(); + + for (idx, node) in self.graph.nodes.iter() { + let mut values: Vec> = if !node.is_input() { + node.inputs() + .iter() + .map(|(idx, outlet)| { + Ok(results.get(idx).ok_or(GraphError::MissingResults)?[*outlet].clone()) + }) + .collect::, GraphError>>()? 
+ } else { + // we re-assign inputs, always from the 0 outlet + vec![results.get(idx).ok_or(GraphError::MissingResults)?[0].clone()] + }; + + debug!( + "laying out {}: {}, row:{}, coord:{}, total_constants: {}", + idx, + node.as_str(), + region.row(), + region.linear_coord(), + region.total_constants() + ); + debug!("dims: {:?}", node.out_dims()); + debug!( + "input_dims {:?}", + values.iter().map(|v| v.dims()).collect_vec() + ); + + match &node { + NodeType::Node(n) => { + let res = if node.is_constant() && node.num_uses() == 1 { + log::debug!("node {} is a constant with 1 use", n.idx); + let mut node = n.clone(); + let c = node.opkind.get_mutable_constant().ok_or("no constant")?; + Some(c.quantized_values.clone().try_into()?) + } else { + config + .base + .layout(region, &values, n.opkind.clone_dyn()) + .map_err(|e| { + error!("{}", e); + halo2_proofs::plonk::Error::Synthesis + })? + }; + + if let Some(vt) = res { + // we get the max as for fused nodes this corresponds to the node output + results.insert(*idx, vec![vt.clone()]); + //only use with mock prover + debug!("------------ output node {:?}: {:?}", idx, vt.show()); + } + } + NodeType::SubGraph { + model, + inputs, + output_mappings, + input_mappings, + .. 
+ } => { + let original_values = values.clone(); + let input_mappings = input_mappings.clone(); + + let input_dims = values.iter().map(|inp| inp.dims()); + let num_iter = number_of_iterations(&input_mappings, input_dims.collect()); + + debug!( + "{} iteration(s) in a subgraph with inputs {:?} and sources {:?}", + num_iter, inputs, model.graph.inputs + ); + + let mut full_results: Vec> = vec![]; + + for i in 0..num_iter { + debug!(" -------------- subgraph iteration: {}", i); + // replace the Stacked input with the current chunk iter + for ((mapping, inp), og_inp) in + input_mappings.iter().zip(&mut values).zip(&original_values) + { + if let InputMapping::Stacked { axis, chunk } = mapping { + let start = i * chunk; + let end = (i + 1) * chunk; + let mut sliced_input = og_inp.clone(); + sliced_input.slice(axis, &start, &end)?; + *inp = sliced_input; + } + } + + let mut subgraph_results = BTreeMap::from_iter( + model + .graph + .inputs + .clone() + .into_iter() + .zip(values.clone().into_iter().map(|v| vec![v])), + ); + + let res = model.layout_nodes(config, region, &mut subgraph_results)?; + + let mut outlets = BTreeMap::new(); + + for (mappings, outlet_res) in output_mappings.iter().zip(res) { + for mapping in mappings { + match mapping { + OutputMapping::Single { outlet, .. } => { + outlets.insert(outlet, outlet_res.clone()); + } + OutputMapping::Stacked { outlet, axis, .. 
} => { + if !full_results.is_empty() { + let stacked_res = full_results[*outlet] + .clone() + .concat_axis(outlet_res.clone(), axis)?; + + outlets.insert(outlet, stacked_res); + } else { + outlets.insert(outlet, outlet_res.clone()); + } + } + } + } + } + + full_results = outlets.into_values().collect_vec(); + + let output_states = output_state_idx(output_mappings); + let input_states = input_state_idx(&input_mappings); + + assert_eq!(input_states.len(), output_states.len()); + + for (input_idx, output_idx) in input_states.iter().zip(output_states) { + values[*input_idx] = full_results[output_idx].clone(); + } + } + + //only use with mock prover + trace!( + "------------ output subgraph node {:?}: {:?}", + idx, + full_results.iter().map(|x| x.show()).collect_vec() + ); + + results.insert(*idx, full_results); + } + } + } + + // we do this so we can support multiple passes of the same model and have deterministic results (Non-assigned inputs etc... etc...) + results.extend(orig_inputs); + + let output_nodes = self.graph.outputs.iter(); + debug!( + "model outputs are nodes: {:?}", + output_nodes.clone().collect_vec() + ); + let outputs = output_nodes + .map(|(idx, outlet)| { + Ok(results.get(idx).ok_or(GraphError::MissingResults)?[*outlet].clone()) + }) + .collect::, GraphError>>()?; + + Ok(outputs) + } + + /// Assigns dummy values to the regions created when calling `configure`. + /// # Arguments + /// * `input_shapes` - The shapes of the inputs to the model. 
+ pub fn dummy_layout( + &self, + run_args: &RunArgs, + input_shapes: &[Vec], + ) -> Result<(usize, usize, usize), Box> { + info!("calculating num of constraints using dummy model layout..."); + + let start_time = instant::Instant::now(); + + let mut results = BTreeMap::>>::new(); + let default_value = if !self.visibility.input.is_fixed() { + ValType::Value(Value::::unknown()) + } else { + ValType::Constant(Fp::ONE) + }; + + let inputs: Vec> = input_shapes + .iter() + .map(|shape| { + let mut t: ValTensor = + vec![default_value.clone(); shape.iter().product()].into(); + t.reshape(shape)?; + Ok(t) + }) + .collect::, Box>>()?; + + for (i, input_idx) in self.graph.inputs.iter().enumerate() { + results.insert(*input_idx, vec![inputs[i].clone()]); + } + + let mut dummy_config = + PolyConfig::dummy(run_args.logrows as usize, run_args.num_inner_cols); + let mut model_config = ModelConfig { + base: dummy_config.clone(), + vars: ModelVars::new_dummy(), + }; + + let mut region = RegionCtx::new_dummy(0, run_args.num_inner_cols); + + let outputs = self.layout_nodes(&mut model_config, &mut region, &mut results)?; + + if self.visibility.output.is_public() || self.visibility.output.is_fixed() { + let default_value = if !self.visibility.output.is_fixed() { + ValType::Value(Value::::unknown()) + } else { + ValType::Constant(Fp::ONE) + }; + + let comparator = outputs + .iter() + .map(|x| { + let mut v: ValTensor = + vec![default_value.clone(); x.dims().iter().product::()].into(); + v.reshape(x.dims())?; + Ok(v) + }) + .collect::, Box>>()?; + + let _ = outputs + .into_iter() + .zip(comparator) + .map(|(o, c)| { + dummy_config.layout( + &mut region, + &[o, c], + Box::new(HybridOp::RangeCheck(run_args.tolerance)), + ) + }) + .collect::, _>>()?; + } else if !self.visibility.output.is_private() { + for output in &outputs { + region.increment_total_constants(output.num_constants()); + } + } + + let duration = start_time.elapsed(); + trace!("dummy model layout took: {:?}", duration); + + 
// Then number of columns in the circuits + #[cfg(not(target_arch = "wasm32"))] + info!( + "{} {} {} (coord={}, constants={})", + "model uses".blue(), + region.row().to_string().blue(), + "rows".blue(), + region.linear_coord().to_string().yellow(), + region.total_constants().to_string().red() + ); + + Ok(( + region.row(), + region.linear_coord(), + region.total_constants(), + )) + } + + /// Retrieves all constants from the model. + pub fn get_all_params(&self) -> Vec> { + let mut params = vec![]; + for node in self.graph.nodes.values() { + match node { + NodeType::Node(_) => { + if let Some(constant) = extract_const_quantized_values(node.opkind()) { + params.push(constant); + } + } + NodeType::SubGraph { model, .. } => { + params.extend(model.get_all_params()); + } + } + } + params + } + + /// Shapes of the computational graph's constants + pub fn const_shapes(&self) -> Vec> { + let mut const_shapes = vec![]; + for node in self.graph.nodes.values() { + match node { + NodeType::Node(_) => { + if let Some(constant) = extract_const_quantized_values(node.opkind()) { + const_shapes.push(constant.dims().to_vec()); + }; + } + NodeType::SubGraph { model, .. } => { + const_shapes.extend(model.const_shapes()); + } + } + } + const_shapes + } + + /// Replaces all constants in the model with the provided values (in order of indexing), returns the number of consts + pub fn replace_consts(&mut self, consts: &[ValTensor]) -> usize { + let mut const_idx = 0; + for node in self.graph.nodes.values_mut() { + match node { + NodeType::Node(n) => { + if let SupportedOp::Constant(c) = &n.opkind { + let mut op = crate::circuit::Constant::new( + c.quantized_values.clone(), + c.raw_values.clone(), + ); + op.pre_assign(consts[const_idx].clone()); + n.opkind = SupportedOp::Constant(op); + + const_idx += 1; + } + } + NodeType::SubGraph { model, .. 
} => { + let total_consts = model.replace_consts(&consts[const_idx..]); + const_idx += total_consts; + } + } + } + const_idx + } + + /// Shapes of the computational graph's public inputs (if any) + pub fn instance_shapes(&self) -> Result>, Box> { + let mut instance_shapes = vec![]; + if self.visibility.input.is_public() { + instance_shapes.extend(self.graph.input_shapes()?); + } + if self.visibility.output.is_public() { + instance_shapes.extend(self.graph.output_shapes()?); + } + Ok(instance_shapes) + } +} diff --git a/mnist_ezkl/src/graph/modules.rs b/mnist_ezkl/src/graph/modules.rs new file mode 100644 index 0000000..f2dceba --- /dev/null +++ b/mnist_ezkl/src/graph/modules.rs @@ -0,0 +1,501 @@ +use crate::circuit::modules::elgamal::{ElGamalConfig, ElGamalGadget, ElGamalVariables}; +use crate::circuit::modules::kzg::{KZGChip, KZGConfig}; +use crate::circuit::modules::poseidon::spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH}; +use crate::circuit::modules::poseidon::{PoseidonChip, PoseidonConfig}; +use crate::circuit::modules::Module; +use crate::tensor::{Tensor, ValTensor, ValType}; +use halo2_proofs::circuit::{Layouter, Value}; +use halo2_proofs::plonk::{Column, ConstraintSystem, Error, Instance, VerifyingKey}; +use halo2_proofs::poly::kzg::commitment::ParamsKZG; +use halo2curves::bn256::{Bn256, Fr as Fp, G1Affine}; +use itertools::Itertools; +use serde::{Deserialize, Serialize}; + +use super::GraphWitness; +use super::{VarVisibility, Visibility}; + +/// poseidon len to hash in tree +pub const POSEIDON_LEN_GRAPH: usize = 32; + +/// ElGamal number of instances +pub const ELGAMAL_INSTANCES: usize = 4; +/// Poseidon number of instancess +pub const POSEIDON_INSTANCES: usize = 1; + +/// Poseidon module type +pub type ModulePoseidon = + PoseidonChip; +/// Poseidon module config +pub type ModulePoseidonConfig = PoseidonConfig; + +/// +#[derive(Clone, Debug, Default)] +pub struct ModuleConfigs { + /// KZG + kzg: Vec, + /// Poseidon + poseidon: Option, + /// ElGamal + 
elgamal: Option, + /// Instance + pub instance: Option>, +} + +impl ModuleConfigs { + /// Create new module configs from visibility of each variable + pub fn from_visibility( + cs: &mut ConstraintSystem, + module_size: ModuleSizes, + logrows: usize, + ) -> Self { + let mut config = Self::default(); + + for size in module_size.kzg { + config.kzg.push(KZGChip::configure(cs, (logrows, size))); + } + + config + } + + /// Configure the modules + pub fn configure_complex_modules( + &mut self, + cs: &mut ConstraintSystem, + visibility: VarVisibility, + module_size: ModuleSizes, + ) { + if (visibility.input.is_encrypted() + || visibility.output.is_encrypted() + || visibility.params.is_encrypted()) + && module_size.elgamal.1[0] > 0 + { + let elgamal = ElGamalGadget::configure(cs, ()); + self.instance = Some(elgamal.instance); + self.elgamal = Some(elgamal); + }; + + if (visibility.input.is_hashed() + || visibility.output.is_hashed() + || visibility.params.is_hashed()) + && module_size.poseidon.1[0] > 0 + { + if visibility.input.is_hashed_public() + || visibility.output.is_hashed_public() + || visibility.params.is_hashed_public() + { + if let Some(inst) = self.instance { + self.poseidon = Some(ModulePoseidon::configure_with_optional_instance( + cs, + Some(inst), + )); + } else { + let poseidon = ModulePoseidon::configure(cs, ()); + self.instance = poseidon.instance; + self.poseidon = Some(poseidon); + } + } else if visibility.input.is_hashed_private() + || visibility.output.is_hashed_private() + || visibility.params.is_hashed_private() + { + self.poseidon = Some(ModulePoseidon::configure_with_optional_instance(cs, None)); + } + }; + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +/// Module variable settings +pub struct ModuleVarSettings { + /// + elgamal: Option, +} + +impl ModuleVarSettings { + /// Create new module variable settings + pub fn new(elgamal: ElGamalVariables) -> Self { + ModuleVarSettings { + elgamal: Some(elgamal), + } + } +} + +impl Default for 
ModuleVarSettings { + fn default() -> Self { + let dummy_elgamal = ElGamalVariables::default(); + ModuleVarSettings { + elgamal: Some(dummy_elgamal), + } + } +} + +#[derive(Clone, Debug, Default, Serialize, Deserialize)] +/// Module input settings +pub struct ModuleSettings { + /// + pub input: ModuleVarSettings, + /// + pub params: ModuleVarSettings, + /// + pub output: ModuleVarSettings, +} + +impl From<&GraphWitness> for ModuleSettings { + fn from(graph_input: &GraphWitness) -> Self { + let mut settings = Self::default(); + + if let Some(processed_inputs) = &graph_input.processed_inputs { + if let Some(elgamal_result) = &processed_inputs.elgamal { + settings.input = ModuleVarSettings::new(elgamal_result.variables.clone()); + } + } + if let Some(processed_params) = &graph_input.processed_params { + if let Some(elgamal_result) = &processed_params.elgamal { + settings.params = ModuleVarSettings::new(elgamal_result.variables.clone()); + } + } + if let Some(processed_outputs) = &graph_input.processed_outputs { + if let Some(elgamal_result) = &processed_outputs.elgamal { + settings.output = ModuleVarSettings::new(elgamal_result.variables.clone()); + } + } + + settings + } +} + +#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq)] +/// Result from ElGamal +pub struct ElGamalResult { + /// ElGamal variables + pub variables: ElGamalVariables, + /// ElGamal ciphertexts + pub ciphertexts: Vec>, + /// ElGamal encrypted message + pub encrypted_messages: Vec>, +} + +/// Result from a forward pass +#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq)] +pub struct ModuleForwardResult { + /// The inputs of the forward pass for poseidon + pub poseidon_hash: Option>, + /// The outputs of the forward pass for ElGamal + pub elgamal: Option, + /// The outputs of the forward pass for KZG + pub kzg_commit: Option>>, +} + +impl ModuleForwardResult { + /// Get the result + pub fn get_result(&self, vis: Visibility) -> Vec> { + if vis.is_hashed() { + 
self.poseidon_hash + .clone() + .unwrap() + .into_iter() + .map(|x| vec![x]) + .collect() + } else if vis.is_encrypted() { + self.elgamal.clone().unwrap().encrypted_messages + } else { + vec![] + } + } + + /// get instances + pub fn get_instances(&self) -> Vec> { + if let Some(poseidon) = &self.poseidon_hash { + poseidon.iter().map(|x| vec![*x]).collect() + } else if let Some(elgamal) = &self.elgamal { + elgamal.ciphertexts.clone() + } else { + vec![] + } + } +} + +#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] +/// +pub struct ModuleSizes { + kzg: Vec, + poseidon: (usize, Vec), + elgamal: (usize, Vec), +} + +impl ModuleSizes { + /// Create new module sizes + pub fn new() -> Self { + ModuleSizes { + kzg: vec![], + poseidon: ( + 0, + vec![0; crate::circuit::modules::poseidon::NUM_INSTANCE_COLUMNS], + ), + elgamal: ( + 0, + vec![0; crate::circuit::modules::elgamal::NUM_INSTANCE_COLUMNS], + ), + } + } + + /// Get the number of constraints + pub fn max_constraints(&self) -> usize { + self.poseidon.0.max(self.elgamal.0) + } + /// Get the number of instances + pub fn num_instances(&self) -> Vec { + // concat + self.poseidon + .1 + .iter() + .chain(self.elgamal.1.iter()) + .copied() + .collect_vec() + } +} + +/// Graph modules that can process inputs, params and outputs beyond the basic operations +#[derive(Clone, Debug, Default, Serialize, Deserialize)] +pub struct GraphModules { + kzg_idx: usize, +} +impl GraphModules { + /// + pub fn new() -> GraphModules { + GraphModules { kzg_idx: 0 } + } + + /// + pub fn reset_index(&mut self) { + self.kzg_idx = 0; + } +} + +impl GraphModules { + fn num_constraint_given_shapes( + visibility: Visibility, + shapes: Vec>, + sizes: &mut ModuleSizes, + ) { + for shape in shapes { + let total_len = shape.iter().product::(); + if total_len > 0 { + if visibility.is_kzgcommit() { + // 1 constraint for each kzg commitment + sizes.kzg.push(total_len); + } else if visibility.is_hashed() { + sizes.poseidon.0 += 
ModulePoseidon::num_rows(total_len); + // 1 constraints for hash + sizes.poseidon.1[0] += 1; + } else if visibility.is_encrypted() { + // add the 1 time fixed cost of maingate + ecc chips + sizes.elgamal.0 += ElGamalGadget::num_rows(total_len); + // 4 constraints for each ciphertext c1, c2, and sk + sizes.elgamal.1[0] += 4; + } + } + } + } + /// Get the number of constraints and instances for the module + pub fn num_constraints_and_instances( + input_shapes: Vec>, + params_shapes: Vec>, + output_shapes: Vec>, + visibility: VarVisibility, + ) -> ModuleSizes { + let mut module_sizes = ModuleSizes::new(); + + Self::num_constraint_given_shapes(visibility.input, input_shapes, &mut module_sizes); + Self::num_constraint_given_shapes(visibility.params, params_shapes, &mut module_sizes); + Self::num_constraint_given_shapes(visibility.output, output_shapes, &mut module_sizes); + + module_sizes + } + + /// Layout the module + fn layout_module( + module: &impl Module, + layouter: &mut impl Layouter, + x: &mut Vec>, + instance_offset: &mut usize, + ) -> Result<(), Error> { + // reserve module 0 for ... 
modules + // hash the input and replace the constrained cells in the input + let cloned_x = (*x).clone(); + x[0] = module + .layout(layouter, &cloned_x, instance_offset.to_owned()) + .unwrap(); + for inc in module.instance_increment_input().iter() { + // increment the instance offset to make way for future module layouts + *instance_offset += inc; + } + + Ok(()) + } + + /// Layout the module + pub fn layout( + &mut self, + layouter: &mut impl Layouter, + configs: &mut ModuleConfigs, + values: &mut [ValTensor], + element_visibility: &Visibility, + instance_offset: &mut usize, + module_settings: &ModuleVarSettings, + ) -> Result<(), Error> { + if element_visibility.is_kzgcommit() && !values.is_empty() { + // concat values and sk to get the inputs + let mut inputs = values.iter_mut().map(|x| vec![x.clone()]).collect_vec(); + + // layout the module + inputs.iter_mut().for_each(|x| { + // create the module + let chip = KZGChip::new(configs.kzg[self.kzg_idx].clone()); + // reserve module 2 onwards for kzg modules + let module_offset = 3 + self.kzg_idx; + layouter + .assign_region(|| format!("_enter_module_{}", module_offset), |_| Ok(())) + .unwrap(); + Self::layout_module(&chip, layouter, x, instance_offset).unwrap(); + // increment the current index + self.kzg_idx += 1; + }); + + // replace the inputs with the outputs + values.iter_mut().enumerate().for_each(|(i, x)| { + x.clone_from(&inputs[i][0]); + }); + } + + // If the module is hashed, then we need to hash the inputs + if element_visibility.is_hashed() && !values.is_empty() { + if let Some(config) = &mut configs.poseidon { + // reserve module 0 for poseidon modules + layouter.assign_region(|| "_enter_module_0", |_| Ok(()))?; + // create the module + let chip = ModulePoseidon::new(config.clone()); + // concat values and sk to get the inputs + let mut inputs = values.iter_mut().map(|x| vec![x.clone()]).collect_vec(); + // layout the module + inputs.iter_mut().for_each(|x| { + Self::layout_module(&chip, layouter, x, 
instance_offset).unwrap(); + }); + // replace the inputs with the outputs + values.iter_mut().enumerate().for_each(|(i, x)| { + x.clone_from(&inputs[i][0]); + }); + } else { + log::error!("Poseidon config not initialized"); + return Err(Error::Synthesis); + } + // If the module is encrypted, then we need to encrypt the inputs + } else if element_visibility.is_encrypted() && !values.is_empty() { + if let Some(config) = &mut configs.elgamal { + // reserve module 1 for elgamal modules + layouter.assign_region(|| "_enter_module_1", |_| Ok(()))?; + // create the module + let mut chip = ElGamalGadget::new(config.clone()); + // load the variables + let variables = module_settings.elgamal.as_ref().unwrap().clone(); + chip.load_variables(variables.clone()); + // load the sk: + let sk: Tensor> = + Tensor::new(Some(&[Value::known(variables.sk).into()]), &[1]).unwrap(); + // concat values and sk to get the inputs + let mut inputs = values + .iter_mut() + .map(|x| vec![x.clone(), sk.clone().into()]) + .collect_vec(); + // layout the module + inputs.iter_mut().for_each(|x| { + Self::layout_module(&chip, layouter, x, instance_offset).unwrap(); + chip.config.initialized = true; + }); + // replace the inputs with the outputs + values.iter_mut().enumerate().for_each(|(i, x)| { + x.clone_from(&inputs[i][0]); + }); + + config.initialized = true; + } else { + log::error!("ElGamal config not initialized"); + return Err(Error::Synthesis); + } + } + + Ok(()) + } + + /// Run forward pass + pub fn forward( + inputs: &[Tensor], + element_visibility: Visibility, + vk: Option<&VerifyingKey>, + srs: Option<&ParamsKZG>, + ) -> Result> { + let mut rng = &mut rand::thread_rng(); + let mut poseidon_hash = None; + let mut elgamal = None; + let mut kzg_commit = None; + + if element_visibility.is_hashed() { + let field_elements = inputs.iter().fold(vec![], |mut acc, x| { + let res = ModulePoseidon::run(x.to_vec()).unwrap()[0].clone(); + acc.extend(res); + acc + }); + poseidon_hash = 
Some(field_elements); + } + + if element_visibility.is_kzgcommit() { + if let Some(vk) = vk { + if let Some(srs) = srs { + let commitments = inputs.iter().fold(vec![], |mut acc, x| { + let res = KZGChip::commit( + x.to_vec(), + vk.cs().degree() as u32, + (vk.cs().blinding_factors() + 1) as u32, + srs, + ); + acc.push(res); + acc + }); + kzg_commit = Some(commitments); + } else { + log::warn!("no srs provided for kzgcommit. processed value will be none"); + } + } else { + log::debug!( + "no verifying key provided for kzgcommit. processed value will be none" + ); + } + } + + if element_visibility.is_encrypted() { + let variables = ElGamalVariables::gen_random(&mut rng); + let ciphertexts = inputs.iter().fold(vec![], |mut acc, x| { + let res = ElGamalGadget::run((x.to_vec(), variables.clone())).unwrap(); + acc.extend(res); + acc + }); + + let encrypted_messages = inputs.iter().fold(vec![], |mut acc, x| { + let res = ElGamalGadget::encrypt(variables.pk, x.to_vec(), variables.r).c2; + acc.push(res); + acc + }); + + elgamal = Some(ElGamalResult { + variables, + ciphertexts, + encrypted_messages, + }); + } + + Ok(ModuleForwardResult { + poseidon_hash, + elgamal, + kzg_commit, + }) + } +} diff --git a/mnist_ezkl/src/graph/node.rs b/mnist_ezkl/src/graph/node.rs new file mode 100644 index 0000000..6b23cf8 --- /dev/null +++ b/mnist_ezkl/src/graph/node.rs @@ -0,0 +1,746 @@ +use super::scale_to_multiplier; +#[cfg(not(target_arch = "wasm32"))] +use super::utilities::node_output_shapes; +#[cfg(not(target_arch = "wasm32"))] +use super::VarScales; +#[cfg(not(target_arch = "wasm32"))] +use super::Visibility; +use crate::circuit::hybrid::HybridOp; +use crate::circuit::lookup::LookupOp; +use crate::circuit::poly::PolyOp; +use crate::circuit::Constant; +use crate::circuit::Input; +use crate::circuit::Op; +use crate::circuit::Unknown; +use crate::fieldutils::felt_to_i128; +use crate::fieldutils::i128_to_felt; +#[cfg(not(target_arch = "wasm32"))] +use crate::graph::new_op_from_onnx; +use 
crate::tensor::Tensor; +use crate::tensor::TensorError; +use halo2curves::bn256::Fr as Fp; +#[cfg(not(target_arch = "wasm32"))] +use itertools::Itertools; +#[cfg(not(target_arch = "wasm32"))] +use log::trace; +use serde::Deserialize; +use serde::Serialize; +#[cfg(not(target_arch = "wasm32"))] +use std::collections::BTreeMap; +use std::error::Error; +#[cfg(not(target_arch = "wasm32"))] +use std::fmt; +#[cfg(not(target_arch = "wasm32"))] +use tabled::Tabled; +#[cfg(not(target_arch = "wasm32"))] +use tract_onnx::{ + self, + prelude::{Node as OnnxNode, SymbolValues, TypedFact, TypedOp}, +}; + +#[cfg(not(target_arch = "wasm32"))] +fn display_vector(v: &Vec) -> String { + if !v.is_empty() { + format!("{:?}", v) + } else { + String::new() + } +} + +#[cfg(not(target_arch = "wasm32"))] +fn display_opkind(v: &SupportedOp) -> String { + v.as_string() +} + +/// A wrapper for an operation that has been rescaled. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Rescaled { + /// The operation that has to be rescaled. + pub inner: Box, + /// The scale of the operation's inputs. 
+ pub scale: Vec<(usize, u128)>, +} + +impl Op for Rescaled { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn f(&self, x: &[Tensor]) -> Result, TensorError> { + if self.scale.len() != x.len() { + return Err(TensorError::DimMismatch("rescaled inputs".to_string())); + } + + let mut rescaled_inputs = vec![]; + let inputs = &mut x.to_vec(); + for (i, ri) in inputs.iter_mut().enumerate() { + let mult_tensor = Tensor::from([Fp::from(self.scale[i].1 as u64)].into_iter()); + let res = (ri.clone() * mult_tensor)?; + rescaled_inputs.push(res); + } + Op::::f(&*self.inner, &rescaled_inputs) + } + + fn as_string(&self) -> String { + format!("RESCALED INPUT ({})", self.inner.as_string()) + } + + fn out_scale(&self, in_scales: Vec) -> Result> { + let in_scales = in_scales + .into_iter() + .zip(self.scale.iter()) + .map(|(a, b)| a + crate::graph::multiplier_to_scale(b.1 as f64)) + .collect(); + + Op::::out_scale(&*self.inner, in_scales) + } + + fn required_lookups(&self) -> Vec { + self.inner.required_lookups() + } + + fn layout( + &self, + config: &mut crate::circuit::BaseConfig, + region: &mut crate::circuit::region::RegionCtx, + values: &[crate::tensor::ValTensor], + ) -> Result>, Box> { + if self.scale.len() != values.len() { + return Err(Box::new(TensorError::DimMismatch( + "rescaled inputs".to_string(), + ))); + } + + let res = + &crate::circuit::layouts::rescale(config, region, values[..].try_into()?, &self.scale)? + [..]; + self.inner.layout(config, region, res) + } + + fn clone_dyn(&self) -> Box> { + Box::new(self.clone()) // Forward to the derive(Clone) impl + } +} + +/// A wrapper for an operation that has been rescaled. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct RebaseScale { + /// The operation that has to be rescaled. + pub inner: Box, + /// the multiplier applied to the node output + pub multiplier: f64, + /// scale being rebased to + pub target_scale: i32, + /// The original scale of the operation's inputs. 
+ pub original_scale: i32, +} + +impl RebaseScale { + /// + pub fn rebase( + inner: SupportedOp, + global_scale: crate::Scale, + op_out_scale: crate::Scale, + scale_rebase_multiplier: u32, + ) -> SupportedOp { + if (op_out_scale > (global_scale * scale_rebase_multiplier as i32)) + && !inner.is_constant() + && !inner.is_input() + { + let multiplier = + scale_to_multiplier(op_out_scale - global_scale * scale_rebase_multiplier as i32); + if let Some(op) = inner.get_rebased() { + SupportedOp::RebaseScale(RebaseScale { + inner: op.inner.clone(), + target_scale: op.target_scale, + multiplier: op.multiplier * multiplier, + original_scale: op.original_scale, + }) + } else { + SupportedOp::RebaseScale(RebaseScale { + inner: Box::new(inner), + target_scale: global_scale * scale_rebase_multiplier as i32, + multiplier, + original_scale: op_out_scale, + }) + } + } else { + inner + } + } + + /// + pub fn rebase_up( + inner: SupportedOp, + target_scale: crate::Scale, + op_out_scale: crate::Scale, + ) -> SupportedOp { + if (op_out_scale < (target_scale)) && !inner.is_constant() && !inner.is_input() { + let multiplier = scale_to_multiplier(op_out_scale - target_scale); + if let Some(op) = inner.get_rebased() { + SupportedOp::RebaseScale(RebaseScale { + inner: op.inner.clone(), + target_scale: op.target_scale, + multiplier: op.multiplier * multiplier, + original_scale: op.original_scale, + }) + } else { + SupportedOp::RebaseScale(RebaseScale { + inner: Box::new(inner), + target_scale, + multiplier, + original_scale: op_out_scale, + }) + } + } else { + inner + } + } +} + +impl Op for RebaseScale { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn f(&self, x: &[Tensor]) -> Result, TensorError> { + let mut res = Op::::f(&*self.inner, x)?; + + let ri = res.output.map(felt_to_i128); + let rescaled = crate::tensor::ops::nonlinearities::const_div(&ri, self.multiplier); + res.output = rescaled.map(i128_to_felt); + + res.intermediate_lookups.push(ri); + + Ok(res) + } + + fn 
as_string(&self) -> String { + format!( + "REBASED (div={:?}) ({})", + self.multiplier, + self.inner.as_string() + ) + } + + fn out_scale(&self, _: Vec) -> Result> { + Ok(self.target_scale) + } + + fn required_lookups(&self) -> Vec { + let mut lookups = self.inner.required_lookups(); + lookups.push(LookupOp::Div { + denom: crate::circuit::utils::F32(self.multiplier as f32), + }); + lookups + } + + fn layout( + &self, + config: &mut crate::circuit::BaseConfig, + region: &mut crate::circuit::region::RegionCtx, + values: &[crate::tensor::ValTensor], + ) -> Result>, Box> { + let original_res = self + .inner + .layout(config, region, values)? + .ok_or("no layout")?; + + Ok(Some(crate::circuit::layouts::nonlinearity( + config, + region, + &[original_res], + &LookupOp::Div { + denom: crate::circuit::utils::F32(self.multiplier as f32), + }, + )?)) + } + + fn clone_dyn(&self) -> Box> { + Box::new(self.clone()) // Forward to the derive(Clone) impl + } +} + +/// A single operation in a [crate::graph::Model]. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub enum SupportedOp { + /// A linear operation. + Linear(PolyOp), + /// A nonlinear operation. + Nonlinear(LookupOp), + /// A hybrid operation. 
+ Hybrid(HybridOp), + /// + Input(Input), + /// + Constant(Constant), + /// + Unknown(Unknown), + /// + Rescaled(Rescaled), + /// + RebaseScale(RebaseScale), +} + +impl SupportedOp { + /// + pub fn is_lookup(&self) -> bool { + match self { + SupportedOp::Nonlinear(_) => true, + SupportedOp::RebaseScale(op) => op.inner.is_lookup(), + _ => false, + } + } + /// + pub fn get_input(&self) -> Option { + match self { + SupportedOp::Input(op) => Some(op.clone()), + _ => None, + } + } + + /// + pub fn get_rebased(&self) -> Option<&RebaseScale> { + match self { + SupportedOp::RebaseScale(op) => Some(op), + _ => None, + } + } + + /// + pub fn get_lookup(&self) -> Option<&LookupOp> { + match self { + SupportedOp::Nonlinear(op) => Some(op), + _ => None, + } + } + + /// + pub fn get_constant(&self) -> Option<&Constant> { + match self { + SupportedOp::Constant(op) => Some(op), + _ => None, + } + } + + /// + pub fn get_mutable_constant(&mut self) -> Option<&mut Constant> { + match self { + SupportedOp::Constant(op) => Some(op), + _ => None, + } + } + + #[cfg(not(target_arch = "wasm32"))] + fn homogenous_rescale( + &self, + in_scales: Vec, + ) -> Result>, Box> { + let inputs_to_scale = self.requires_homogenous_input_scales(); + // creates a rescaled op if the inputs are not homogenous + let op = self.clone_dyn(); + super::homogenize_input_scales(op, in_scales, inputs_to_scale) + } +} + +impl From>> for SupportedOp { + fn from(value: Box>) -> Self { + if let Some(op) = value.as_any().downcast_ref::>() { + return SupportedOp::Linear(op.clone()); + }; + + if let Some(op) = value.as_any().downcast_ref::() { + return SupportedOp::Nonlinear(op.clone()); + }; + + if let Some(op) = value.as_any().downcast_ref::() { + return SupportedOp::Hybrid(op.clone()); + }; + + if let Some(op) = value.as_any().downcast_ref::() { + return SupportedOp::Input(op.clone()); + }; + + if let Some(op) = value.as_any().downcast_ref::>() { + return SupportedOp::Constant(op.clone()); + }; + + if let Some(op) = 
value.as_any().downcast_ref::() { + return SupportedOp::Unknown(op.clone()); + }; + if let Some(op) = value.as_any().downcast_ref::() { + return SupportedOp::Rescaled(op.clone()); + }; + if let Some(op) = value.as_any().downcast_ref::() { + return SupportedOp::RebaseScale(op.clone()); + }; + + log::error!("Unsupported op type"); + log::warn!("defaulting to Unknown"); + SupportedOp::Unknown(Unknown {}) + } +} + +impl Op for SupportedOp { + fn f( + &self, + inputs: &[Tensor], + ) -> Result, crate::tensor::TensorError> { + match self { + SupportedOp::Linear(op) => op.f(inputs), + SupportedOp::Nonlinear(op) => op.f(inputs), + SupportedOp::Hybrid(op) => op.f(inputs), + SupportedOp::Input(op) => op.f(inputs), + SupportedOp::Constant(op) => op.f(inputs), + SupportedOp::Unknown(op) => op.f(inputs), + SupportedOp::Rescaled(op) => op.f(inputs), + SupportedOp::RebaseScale(op) => op.f(inputs), + } + } + + fn layout( + &self, + config: &mut crate::circuit::BaseConfig, + region: &mut crate::circuit::region::RegionCtx, + values: &[crate::tensor::ValTensor], + ) -> Result>, Box> { + match self { + SupportedOp::Linear(op) => op.layout(config, region, values), + SupportedOp::Nonlinear(op) => op.layout(config, region, values), + SupportedOp::Hybrid(op) => op.layout(config, region, values), + SupportedOp::Input(op) => op.layout(config, region, values), + SupportedOp::Constant(op) => op.layout(config, region, values), + SupportedOp::Unknown(op) => op.layout(config, region, values), + SupportedOp::Rescaled(op) => op.layout(config, region, values), + SupportedOp::RebaseScale(op) => op.layout(config, region, values), + } + } + + fn is_input(&self) -> bool { + match self { + SupportedOp::Linear(op) => Op::::is_input(op), + SupportedOp::Nonlinear(op) => Op::::is_input(op), + SupportedOp::Hybrid(op) => Op::::is_input(op), + SupportedOp::Input(op) => Op::::is_input(op), + SupportedOp::Constant(op) => Op::::is_input(op), + SupportedOp::Unknown(op) => Op::::is_input(op), + 
SupportedOp::Rescaled(op) => Op::::is_input(op), + SupportedOp::RebaseScale(op) => Op::::is_input(op), + } + } + + fn is_constant(&self) -> bool { + match self { + SupportedOp::Linear(op) => Op::::is_constant(op), + SupportedOp::Nonlinear(op) => Op::::is_constant(op), + SupportedOp::Hybrid(op) => Op::::is_constant(op), + SupportedOp::Input(op) => Op::::is_constant(op), + SupportedOp::Constant(op) => Op::::is_constant(op), + SupportedOp::Unknown(op) => Op::::is_constant(op), + SupportedOp::Rescaled(op) => Op::::is_constant(op), + SupportedOp::RebaseScale(op) => Op::::is_constant(op), + } + } + + fn requires_homogenous_input_scales(&self) -> Vec { + match self { + SupportedOp::Linear(op) => Op::::requires_homogenous_input_scales(op), + SupportedOp::Nonlinear(op) => Op::::requires_homogenous_input_scales(op), + SupportedOp::Hybrid(op) => Op::::requires_homogenous_input_scales(op), + SupportedOp::Input(op) => Op::::requires_homogenous_input_scales(op), + SupportedOp::Constant(op) => Op::::requires_homogenous_input_scales(op), + SupportedOp::Unknown(op) => Op::::requires_homogenous_input_scales(op), + SupportedOp::Rescaled(op) => Op::::requires_homogenous_input_scales(op), + SupportedOp::RebaseScale(op) => Op::::requires_homogenous_input_scales(op), + } + } + + fn clone_dyn(&self) -> Box> { + match self { + SupportedOp::Linear(op) => Box::new(op.clone()), + SupportedOp::Nonlinear(op) => Box::new(op.clone()), + SupportedOp::Hybrid(op) => Box::new(op.clone()), + SupportedOp::Input(op) => Box::new(op.clone()), + SupportedOp::Constant(op) => Box::new(op.clone()), + SupportedOp::Unknown(op) => Box::new(op.clone()), + SupportedOp::Rescaled(op) => Box::new(op.clone()), + SupportedOp::RebaseScale(op) => Box::new(op.clone()), + } + } + + fn as_string(&self) -> String { + match self { + SupportedOp::Linear(op) => Op::::as_string(op), + SupportedOp::Nonlinear(op) => Op::::as_string(op), + SupportedOp::Hybrid(op) => Op::::as_string(op), + SupportedOp::Input(op) => 
Op::::as_string(op), + SupportedOp::Constant(op) => Op::::as_string(op), + SupportedOp::Unknown(op) => Op::::as_string(op), + SupportedOp::Rescaled(op) => Op::::as_string(op), + SupportedOp::RebaseScale(op) => Op::::as_string(op), + } + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn required_lookups(&self) -> Vec { + match self { + SupportedOp::Linear(op) => Op::::required_lookups(op), + SupportedOp::Nonlinear(op) => Op::::required_lookups(op), + SupportedOp::Hybrid(op) => Op::::required_lookups(op), + SupportedOp::Input(op) => Op::::required_lookups(op), + SupportedOp::Constant(op) => Op::::required_lookups(op), + SupportedOp::Unknown(op) => Op::::required_lookups(op), + SupportedOp::Rescaled(op) => Op::::required_lookups(op), + SupportedOp::RebaseScale(op) => Op::::required_lookups(op), + } + } + + fn out_scale(&self, in_scales: Vec) -> Result> { + match self { + SupportedOp::Linear(op) => Op::::out_scale(op, in_scales), + SupportedOp::Nonlinear(op) => Op::::out_scale(op, in_scales), + SupportedOp::Hybrid(op) => Op::::out_scale(op, in_scales), + SupportedOp::Input(op) => Op::::out_scale(op, in_scales), + SupportedOp::Constant(op) => Op::::out_scale(op, in_scales), + SupportedOp::Unknown(op) => Op::::out_scale(op, in_scales), + SupportedOp::Rescaled(op) => Op::::out_scale(op, in_scales), + SupportedOp::RebaseScale(op) => Op::::out_scale(op, in_scales), + } + } +} + +/// A node's input is a tensor from another node's output. +pub type Outlet = (usize, usize); + +/// A single operation in a [crate::graph::Model]. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Node { + /// [Op] i.e what operation this node represents. + pub opkind: SupportedOp, + /// The denominator in the fixed point representation for the node's output. Tensors of differing scales should not be combined. + pub out_scale: i32, + // Usually there is a simple in and out shape of the node as an operator. 
For example, an Affine node has three input_shapes (one for the input, weight, and bias), + // but in_dim is [in], out_dim is [out] + /// The indices of the node's inputs. + pub inputs: Vec, + /// Dimensions of output. + pub out_dims: Vec, + /// The node's unique identifier. + pub idx: usize, + /// The node's num of uses + pub num_uses: usize, +} + +#[cfg(not(target_arch = "wasm32"))] +impl Tabled for Node { + const LENGTH: usize = 6; + + fn headers() -> Vec> { + let mut headers = Vec::with_capacity(Self::LENGTH); + for i in [ + "idx", + "opkind", + "out_scale", + "inputs", + "out_dims", + "required_lookups", + ] { + headers.push(std::borrow::Cow::Borrowed(i)); + } + headers + } + + fn fields(&self) -> Vec> { + let mut fields = Vec::with_capacity(Self::LENGTH); + fields.push(std::borrow::Cow::Owned(self.idx.to_string())); + fields.push(std::borrow::Cow::Owned(display_opkind(&self.opkind))); + fields.push(std::borrow::Cow::Owned(self.out_scale.to_string())); + fields.push(std::borrow::Cow::Owned(display_vector(&self.inputs))); + fields.push(std::borrow::Cow::Owned(display_vector(&self.out_dims))); + fields.push(std::borrow::Cow::Owned(format!( + "{:?}", + self.opkind + .required_lookups() + .iter() + .map(>::as_string) + .collect_vec() + ))); + fields + } +} + +impl PartialEq for Node { + fn eq(&self, other: &Node) -> bool { + (self.out_scale == other.out_scale) + && (self.inputs == other.inputs) + && (self.out_dims == other.out_dims) + && (self.idx == other.idx) + && (self.opkind.as_string() == other.opkind.as_string()) + } +} + +impl Node { + /// Converts a tract [OnnxNode] into an ezkl [Node]. + /// # Arguments: + /// * `node` - [OnnxNode] + /// * `other_nodes` - [BTreeMap] of other previously initialized [Node]s in the computational graph. + /// * `public_params` - flag if parameters of model are public + /// * `idx` - The node's unique identifier. 
+ #[cfg(not(target_arch = "wasm32"))] + pub fn new( + node: OnnxNode>, + other_nodes: &mut BTreeMap, + scales: &VarScales, + param_visibility: &Visibility, + idx: usize, + symbol_values: &SymbolValues, + ) -> Result> { + trace!("Create {:?}", node); + trace!("Create op {:?}", node.op); + + let num_uses = std::cmp::max( + node.outputs + .iter() + .map(|outlet| outlet.successors.len()) + .sum::(), + // cmp to 1 for outputs + 1, + ); + + // load the node inputs + let mut inputs = vec![]; + + // we can only take the inputs as mutable once -- so we need to collect them first + let mut input_ids = node + .inputs + .iter() + .map(|i| (i.node, i.slot)) + .collect::>(); + + input_ids + .iter() + .map(|(i, _)| Ok(inputs.push(other_nodes.get(i).ok_or("input not found")?.clone()))) + .collect::, Box>>()?; + + let (mut opkind, deleted_indices) = new_op_from_onnx( + idx, + scales, + param_visibility, + node.clone(), + &mut inputs, + symbol_values, + )?; // parses the op name + + // we can only take the inputs as mutable once -- so we need to collect them first + other_nodes.extend( + inputs + .iter() + .map(|i| (i.idx(), i.clone())) + .collect::>(), + ); + + input_ids.iter_mut().enumerate().for_each(|(i, (idx, _))| { + if deleted_indices.contains(&i) { + // this input is not used + *idx = usize::MAX; + } + }); + + // remove the inputs that are not used + input_ids.retain(|(idx, _)| *idx != usize::MAX); + + // rescale the inputs if necessary to get consistent fixed points + let mut in_scales: Vec = input_ids + .iter() + .map(|(idx, outlet)| { + let idx = inputs + .iter() + .position(|x| *idx == x.idx()) + .ok_or("input not found")?; + Ok(inputs[idx].out_scales()[*outlet]) + }) + .collect::, Box>>()?; + + let homogenous_inputs = opkind.requires_homogenous_input_scales(); + // autoamtically increases a constant's scale if it is only used once and + for input in homogenous_inputs + .into_iter() + .filter(|i| !deleted_indices.contains(i)) + { + let input_node = other_nodes + 
.get_mut(&inputs[input].idx()) + .ok_or("input not found")?; + let input_opkind = &mut input_node.opkind(); + if let Some(constant) = input_opkind.get_mutable_constant() { + rescale_const_with_single_use( + constant, + in_scales.clone(), + param_visibility, + input_node.num_uses(), + )?; + input_node.replace_opkind(constant.clone_dyn().into()); + let out_scale = input_opkind.out_scale(vec![])?; + input_node.bump_scale(out_scale); + in_scales[input] = out_scale; + } + } + + opkind = opkind.homogenous_rescale(in_scales.clone())?.into(); + let mut out_scale = opkind.out_scale(in_scales.clone())?; + // rescale the inputs if necessary to get consistent fixed points, we select the largest scale (highest precision) + let global_scale = scales.get_max(); + opkind = RebaseScale::rebase(opkind, global_scale, out_scale, scales.rebase_multiplier); + + out_scale = opkind.out_scale(in_scales)?; + + // get the output shape + let mut out_dims = { + let output_shapes = match node_output_shapes(&node) { + Ok(s) => Some(s), + _ => None, + }; + + if let Some([Some(v)]) = output_shapes.as_deref() { + v.to_vec() + } else if let Some([Some(v), Some(_)]) = output_shapes.as_deref() { + v.to_vec() + } else { + return Err("could not get output shape for node".into()); + } + }; + + if out_dims.is_empty() { + out_dims = vec![1]; + } + + Ok(Node { + idx, + opkind, + inputs: input_ids, + out_dims, + out_scale, + num_uses, + }) + } +} + +#[cfg(not(target_arch = "wasm32"))] +fn rescale_const_with_single_use( + constant: &mut Constant, + in_scales: Vec, + param_visibility: &Visibility, + num_uses: usize, +) -> Result<(), Box> { + if num_uses == 1 { + let current_scale = constant.out_scale(vec![])?; + let scale_max = in_scales.iter().max().ok_or("no scales")?; + if scale_max > ¤t_scale { + let raw_values = constant.raw_values.clone(); + constant.quantized_values = + super::quantize_tensor(raw_values, *scale_max, param_visibility)?; + } + } + + Ok(()) +} diff --git a/mnist_ezkl/src/graph/utilities.rs 
b/mnist_ezkl/src/graph/utilities.rs new file mode 100644 index 0000000..1f51d30 --- /dev/null +++ b/mnist_ezkl/src/graph/utilities.rs @@ -0,0 +1,1449 @@ +#[cfg(not(target_arch = "wasm32"))] +use super::GraphError; +#[cfg(not(target_arch = "wasm32"))] +use super::VarScales; +use super::{Rescaled, SupportedOp, Visibility}; +#[cfg(not(target_arch = "wasm32"))] +use crate::circuit::hybrid::HybridOp; +#[cfg(not(target_arch = "wasm32"))] +use crate::circuit::lookup::LookupOp; +use crate::circuit::poly::PolyOp; +use crate::circuit::Op; +use crate::tensor::{Tensor, TensorError, TensorType}; +use halo2curves::bn256::Fr as Fp; +use halo2curves::ff::PrimeField; +use itertools::Itertools; +#[cfg(not(target_arch = "wasm32"))] +use log::{debug, warn}; +use std::error::Error; +#[cfg(not(target_arch = "wasm32"))] +use std::sync::Arc; +#[cfg(not(target_arch = "wasm32"))] +use tract_onnx::prelude::{DatumType, Node as OnnxNode, TypedFact, TypedOp}; +#[cfg(not(target_arch = "wasm32"))] +use tract_onnx::tract_core::ops::{ + array::{Gather, GatherElements, OneHot, ScatterElements, Slice, Topk}, + change_axes::AxisOp, + cnn::DeconvUnary, + einsum::EinSum, + element_wise::ElementWiseOp, + nn::{LeakyRelu, Reduce, Softmax}, + Downsample, +}; +#[cfg(not(target_arch = "wasm32"))] +use tract_onnx::tract_hir::{ + internal::DimLike, + ops::array::{Pad, PadMode, TypedConcat}, + ops::cnn::{ConvUnary, PoolSpec}, + ops::konst::Const, + ops::nn::DataFormat, + tract_core::ops::cast::Cast, + tract_core::ops::cnn::{conv::KernelFormat, MaxPool, PaddingSpec, SumPool}, +}; + +// Warning: currently ignores stride information +/// Quantizes an iterable of f32s to a [Tensor] of i32s using a fixed point representation. +/// Arguments +/// +/// * `vec` - the vector to quantize. +/// * `dims` - the dimensionality of the resulting [Tensor]. +/// * `shift` - offset used in the fixed point representation. +/// * `scale` - `2^scale` used in the fixed point representation. 
+pub fn quantize_float(elem: &f64, shift: f64, scale: crate::Scale) -> Result { + let mult = scale_to_multiplier(scale); + let max_value = ((i128::MAX as f64 - shift) / mult).round(); // the maximum value that can be represented w/o sig bit truncation + + if *elem > max_value { + return Err(TensorError::SigBitTruncationError); + } + + // we parallelize the quantization process as it seems to be quite slow at times + let scaled = (mult * *elem + shift).round() as i128; + + Ok(scaled) +} + +/// Converts a scale (log base 2) to a fixed point multiplier. +pub fn scale_to_multiplier(scale: crate::Scale) -> f64 { + f64::powf(2., scale as f64) +} + +/// Converts a scale (log base 2) to a fixed point multiplier. +pub fn multiplier_to_scale(mult: f64) -> crate::Scale { + mult.log2().round() as crate::Scale +} + +/// Gets the shape of a onnx node's outlets. +#[cfg(not(target_arch = "wasm32"))] +pub fn node_output_shapes( + node: &OnnxNode>, +) -> Result>>, Box> { + let mut shapes = Vec::new(); + let outputs = node.outputs.to_vec(); + for output in outputs { + let mv = output.fact.shape.clone().as_concrete().map(|x| x.to_vec()); + shapes.push(mv) + } + Ok(shapes) +} +#[cfg(not(target_arch = "wasm32"))] +use tract_onnx::prelude::SymbolValues; +#[cfg(not(target_arch = "wasm32"))] +fn extract_tensor_value( + input: Arc, + symbol_values: &SymbolValues, +) -> Result, Box> { + use rayon::prelude::{IntoParallelRefIterator, ParallelIterator}; + + let dt = input.datum_type(); + let dims = input.shape().to_vec(); + + let mut const_value: Tensor; + if dims.is_empty() && input.len() == 0 { + const_value = Tensor::::new(None, &dims)?; + return Ok(const_value); + } + + match dt { + DatumType::F16 => { + let vec = input.as_slice::()?.to_vec(); + let cast: Vec = vec.par_iter().map(|x| (*x).into()).collect(); + const_value = Tensor::::new(Some(&cast), &dims)?; + } + DatumType::F32 => { + let vec = input.as_slice::()?.to_vec(); + const_value = Tensor::::new(Some(&vec), &dims)?; + } + 
DatumType::F64 => { + let vec = input.as_slice::()?.to_vec(); + let cast: Vec = vec.par_iter().map(|x| *x as f32).collect(); + const_value = Tensor::::new(Some(&cast), &dims)?; + } + DatumType::I64 => { + // Generally a shape or hyperparam + let vec = input.as_slice::()?.to_vec(); + let cast: Vec = vec.par_iter().map(|x| *x as f32).collect(); + const_value = Tensor::::new(Some(&cast), &dims)?; + } + DatumType::I32 => { + // Generally a shape or hyperparam + let vec = input.as_slice::()?.to_vec(); + let cast: Vec = vec.par_iter().map(|x| *x as f32).collect(); + const_value = Tensor::::new(Some(&cast), &dims)?; + } + DatumType::I16 => { + // Generally a shape or hyperparam + let vec = input.as_slice::()?.to_vec(); + let cast: Vec = vec.par_iter().map(|x| *x as f32).collect(); + const_value = Tensor::::new(Some(&cast), &dims)?; + } + DatumType::I8 => { + // Generally a shape or hyperparam + let vec = input.as_slice::()?.to_vec(); + let cast: Vec = vec.par_iter().map(|x| *x as f32).collect(); + const_value = Tensor::::new(Some(&cast), &dims)?; + } + DatumType::U8 => { + // Generally a shape or hyperparam + let vec = input.as_slice::()?.to_vec(); + let cast: Vec = vec.par_iter().map(|x| *x as f32).collect(); + const_value = Tensor::::new(Some(&cast), &dims)?; + } + DatumType::U16 => { + // Generally a shape or hyperparam + let vec = input.as_slice::()?.to_vec(); + let cast: Vec = vec.par_iter().map(|x| *x as f32).collect(); + const_value = Tensor::::new(Some(&cast), &dims)?; + } + DatumType::U32 => { + // Generally a shape or hyperparam + let vec = input.as_slice::()?.to_vec(); + let cast: Vec = vec.par_iter().map(|x| *x as f32).collect(); + const_value = Tensor::::new(Some(&cast), &dims)?; + } + DatumType::U64 => { + // Generally a shape or hyperparam + let vec = input.as_slice::()?.to_vec(); + let cast: Vec = vec.par_iter().map(|x| *x as f32).collect(); + const_value = Tensor::::new(Some(&cast), &dims)?; + } + DatumType::Bool => { + // Generally a shape or hyperparam 
+ let vec = input.as_slice::()?.to_vec(); + let cast: Vec = vec.par_iter().map(|x| *x as usize as f32).collect(); + const_value = Tensor::::new(Some(&cast), &dims)?; + } + DatumType::TDim => { + // Generally a shape or hyperparam + let vec = input.as_slice::()?.to_vec(); + + let cast: Result, &str> = vec + .par_iter() + .map(|x| match x.to_i64() { + Ok(v) => Ok(v as f32), + Err(_) => match x.eval(symbol_values).to_i64() { + Ok(v) => Ok(v as f32), + Err(_) => Err("could not evaluate tdim"), + }, + }) + .collect(); + + const_value = Tensor::::new(Some(&cast?), &dims)?; + } + _ => return Err("unsupported data type".into()), + } + const_value.reshape(&dims)?; + + Ok(const_value) +} + +#[cfg(not(target_arch = "wasm32"))] +fn load_op( + op: &dyn tract_onnx::prelude::Op, + idx: usize, + name: String, +) -> Result> { + // Extract the slope layer hyperparams + let op: &C = match op.downcast_ref::() { + Some(b) => b, + None => { + return Err(Box::new(GraphError::OpMismatch(idx, name))); + } + }; + + Ok(op.clone()) +} + +/// Matches an onnx node to a [crate::circuit::Op]. +/// Arguments +/// * `idx` - the index of the node in the graph. +/// * `scale` - the global (circuit) scale. +/// * `param_visibility` - [Visibility] of the node. +/// * `node` - the [OnnxNode] to be matched. +/// * `inputs` - the node's inputs. 
+#[cfg(not(target_arch = "wasm32"))] +pub fn new_op_from_onnx( + idx: usize, + scales: &VarScales, + param_visibility: &Visibility, + node: OnnxNode>, + inputs: &mut [super::NodeType], + symbol_values: &SymbolValues, +) -> Result<(SupportedOp, Vec), Box> { + use crate::circuit::InputType; + + debug!("Loading node: {:?}", node); + let mut deleted_indices = vec![]; + let node = match node.op().name().as_ref() { + "Range" => { + let mut input_ops = vec![]; + + for (i, input) in inputs.iter_mut().enumerate() { + if !input.opkind().is_constant() { + return Err("Range only supports constant inputs in a zk circuit".into()); + } else { + input.decrement_use(); + deleted_indices.push(i); + input_ops.push(input.opkind().clone()); + } + } + + assert_eq!(input_ops.len(), 3, "Range requires 3 inputs"); + let input_ops = input_ops + .iter() + .map(|x| x.get_constant().ok_or("Range requires constant inputs")) + .collect::, _>>()?; + + let start = input_ops[0].raw_values.map(|x| x as usize)[0]; + let end = input_ops[1].raw_values.map(|x| x as usize)[0]; + let delta = input_ops[2].raw_values.map(|x| x as usize)[0]; + + let range = (start..end).step_by(delta).collect::>(); + let raw_value = range.iter().map(|x| *x as f32).collect::>(); + // Quantize the raw value (integers) + let quantized_value = quantize_tensor(raw_value.clone(), 0, &Visibility::Fixed)?; + + let c = crate::circuit::ops::Constant::new(quantized_value, raw_value); + // Create a constant op + SupportedOp::Constant(c) + } + + "Gather" => { + if inputs.len() != 2 { + return Err(Box::new(GraphError::InvalidDims(idx, "gather".to_string()))); + }; + let op = load_op::(node.op(), idx, node.op().name().to_string())?; + let axis = op.axis; + + let mut op = SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::Gather { + dim: axis, + constant_idx: None, + }); + + // if param_visibility.is_public() { + if let Some(c) = inputs[1].opkind().get_mutable_constant() { + inputs[1].decrement_use(); + 
deleted_indices.push(inputs.len() - 1); + op = SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::Gather { + dim: axis, + constant_idx: Some(c.raw_values.map(|x| x as usize)), + }); + } + // } + + if inputs[1].opkind().is_input() { + inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input { + scale: 0, + datum_type: InputType::TDim, + })); + inputs[1].bump_scale(0); + } + + op + + // Extract the max value + } + "Topk" => { + let op = load_op::(node.op(), idx, node.op().name().to_string())?; + let axis = op.axis; + // if param_visibility.is_public() { + let k = if let Some(c) = inputs[1].opkind().get_mutable_constant() { + inputs[1].decrement_use(); + deleted_indices.push(inputs.len() - 1); + c.raw_values.map(|x| x as usize)[0] + } else { + op.fallback_k.to_i64()? as usize + }; + + SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::TopK { dim: axis, k }) + } + "Onehot" => { + let op = load_op::(node.op(), idx, node.op().name().to_string())?; + let axis = op.axis; + let num_classes = op.dim; + + SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::OneHot { + dim: axis, + num_classes, + }) + } + "ScatterElements" => { + if inputs.len() != 3 { + return Err(Box::new(GraphError::InvalidDims( + idx, + "scatter elements".to_string(), + ))); + }; + let op = load_op::(node.op(), idx, node.op().name().to_string())?; + let axis = op.axis; + + let mut op = + SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::ScatterElements { + dim: axis, + constant_idx: None, + }); + + // if param_visibility.is_public() { + if let Some(c) = inputs[1].opkind().get_mutable_constant() { + inputs[1].decrement_use(); + deleted_indices.push(1); + op = SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::ScatterElements { + dim: axis, + constant_idx: Some(c.raw_values.map(|x| x as usize)), + }) + } + // } + + if inputs[1].opkind().is_input() { + inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input { + scale: 0, + datum_type: 
InputType::TDim, + })); + inputs[1].bump_scale(0); + } + + op + + // Extract the max value + } + "GatherElements" => { + if inputs.len() != 2 { + return Err(Box::new(GraphError::InvalidDims( + idx, + "gather elements".to_string(), + ))); + }; + let op = load_op::(node.op(), idx, node.op().name().to_string())?; + let axis = op.axis; + + let mut op = + SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::GatherElements { + dim: axis, + constant_idx: None, + }); + + // if param_visibility.is_public() { + if let Some(c) = inputs[1].opkind().get_mutable_constant() { + inputs[1].decrement_use(); + deleted_indices.push(inputs.len() - 1); + op = SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::GatherElements { + dim: axis, + constant_idx: Some(c.raw_values.map(|x| x as usize)), + }) + } + // } + + if inputs[1].opkind().is_input() { + inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input { + scale: 0, + datum_type: InputType::TDim, + })); + inputs[1].bump_scale(0); + } + + op + + // Extract the max value + } + "MoveAxis" => { + let op = load_op::(node.op(), idx, node.op().name().to_string())?; + match op { + AxisOp::Move(from, to) => { + let source = from.to_usize()?; + let destination = to.to_usize()?; + SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::MoveAxis { + source, + destination, + }) + } + + _ => { + return Err(Box::new(GraphError::OpMismatch( + idx, + "MoveAxis".to_string(), + ))) + } + } + } + "Concat" | "InferenceConcat" => { + let op = load_op::(node.op(), idx, node.op().name().to_string())?; + let axis = op.axis; + SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::Concat { axis }) + } + "Slice" => { + let slice = load_op::(node.op(), idx, node.op().name().to_string())?; + + let axis = slice.axis; + let start = slice.start.to_usize()?; + let end = slice.end.to_usize()?; + + SupportedOp::Linear(PolyOp::Slice { axis, start, end }) + } + "Const" => { + let op: Const = load_op::(node.op(), idx, 
node.op().name().to_string())?; + let dt = op.0.datum_type(); + // Raw values are always f32 + let raw_value = extract_tensor_value(op.0, symbol_values)?; + // If bool or a tensor dimension then don't scale + let constant_scale = match dt { + DatumType::Bool + | DatumType::TDim + | DatumType::I64 + | DatumType::I32 + | DatumType::I16 + | DatumType::I8 + | DatumType::U8 + | DatumType::U16 + | DatumType::U32 + | DatumType::U64 => 0, + DatumType::F16 | DatumType::F32 | DatumType::F64 => scales.params, + _ => return Err(Box::new(GraphError::UnsupportedDataType)), + }; + + // Quantize the raw value + let quantized_value = + quantize_tensor(raw_value.clone(), constant_scale, param_visibility)?; + let c = crate::circuit::ops::Constant::new(quantized_value, raw_value); + // Create a constant op + SupportedOp::Constant(c) + } + "Reduce" => { + if inputs.len() != 1 { + return Err(Box::new(GraphError::InvalidDims(idx, "argmax".to_string()))); + }; + let op = load_op::(node.op(), idx, node.op().name().to_string())?; + let axes: Vec = op.axes.into_iter().collect(); + assert_eq!(axes.len(), 1, "only support argmax over one axis"); + + SupportedOp::Hybrid(HybridOp::ReduceArgMax { dim: axes[0] }) + } + "Reduce" => { + if inputs.len() != 1 { + return Err(Box::new(GraphError::InvalidDims(idx, "argmin".to_string()))); + }; + let op = load_op::(node.op(), idx, node.op().name().to_string())?; + let axes: Vec = op.axes.into_iter().collect(); + assert_eq!(axes.len(), 1, "only support argmin over one axis"); + + SupportedOp::Hybrid(HybridOp::ReduceArgMin { dim: axes[0] }) + } + "Reduce" => { + if inputs.len() != 1 { + return Err(Box::new(GraphError::InvalidDims(idx, "min".to_string()))); + }; + let op = load_op::(node.op(), idx, node.op().name().to_string())?; + let axes = op.axes.into_iter().collect(); + + SupportedOp::Hybrid(HybridOp::ReduceMin { axes }) + } + "Reduce" => { + if inputs.len() != 1 { + return Err(Box::new(GraphError::InvalidDims(idx, "max".to_string()))); + }; + let op = 
load_op::(node.op(), idx, node.op().name().to_string())?; + let axes = op.axes.into_iter().collect(); + + SupportedOp::Hybrid(HybridOp::ReduceMax { axes }) + } + "Reduce" => { + if inputs.len() != 1 { + return Err(Box::new(GraphError::InvalidDims(idx, "prod".to_string()))); + }; + let op = load_op::(node.op(), idx, node.op().name().to_string())?; + let axes: Vec = op.axes.into_iter().collect(); + + // length of prod along axes + let len_prod = inputs[0].out_dims()[0] + .iter() + .enumerate() + .filter(|(i, _)| axes.contains(i)) + .map(|(_, v)| v) + .product::(); + + SupportedOp::Linear(PolyOp::Prod { axes, len_prod }) + } + "Reduce" => { + if inputs.len() != 1 { + return Err(Box::new(GraphError::InvalidDims(idx, "sum".to_string()))); + }; + let op = load_op::(node.op(), idx, node.op().name().to_string())?; + let axes = op.axes.into_iter().collect(); + + SupportedOp::Linear(PolyOp::Sum { axes }) + } + "Max" => { + // Extract the max value + // first find the input that is a constant + // and then extract the value + let const_inputs = inputs + .iter() + .enumerate() + .filter(|(_, n)| n.is_constant()) + .map(|(i, _)| i) + .collect::>(); + + if const_inputs.len() != 1 { + return Err(Box::new(GraphError::OpMismatch(idx, "Max".to_string()))); + } + + let const_idx = const_inputs[0]; + let boxed_op = inputs[const_idx].opkind(); + let unit = if let Some(c) = extract_const_raw_values(boxed_op) { + if c.len() == 1 { + c[0] + } else { + return Err(Box::new(GraphError::InvalidDims(idx, "max".to_string()))); + } + } else { + return Err(Box::new(GraphError::OpMismatch(idx, "Max".to_string()))); + }; + + if inputs.len() == 2 { + if let Some(node) = inputs.get_mut(const_idx) { + node.decrement_use(); + deleted_indices.push(const_idx); + } + if unit == 0. 
{ + SupportedOp::Nonlinear(LookupOp::ReLU) + } else { + SupportedOp::Nonlinear(LookupOp::Max { + scales: (1, 1), + a: crate::circuit::utils::F32(unit), + }) + } + } else { + return Err(Box::new(GraphError::InvalidDims(idx, "max".to_string()))); + } + } + "Min" => { + // Extract the min value + // first find the input that is a constant + // and then extract the value + let const_inputs = inputs + .iter() + .enumerate() + .filter(|(_, n)| n.is_constant()) + .map(|(i, _)| i) + .collect::>(); + + if const_inputs.len() != 1 { + return Err(Box::new(GraphError::OpMismatch(idx, "Min".to_string()))); + } + + let const_idx = const_inputs[0]; + let boxed_op = inputs[const_idx].opkind(); + let unit = if let Some(c) = extract_const_raw_values(boxed_op) { + if c.len() == 1 { + c[0] + } else { + return Err(Box::new(GraphError::InvalidDims(idx, "min".to_string()))); + } + } else { + return Err(Box::new(GraphError::OpMismatch(idx, "Min".to_string()))); + }; + + if inputs.len() == 2 { + if let Some(node) = inputs.get_mut(const_idx) { + node.decrement_use(); + deleted_indices.push(const_idx); + } + SupportedOp::Nonlinear(LookupOp::Min { + scales: (1, 1), + a: crate::circuit::utils::F32(unit), + }) + } else { + return Err(Box::new(GraphError::InvalidDims(idx, "min".to_string()))); + } + } + "Recip" => { + // Extract the slope layer hyperparams + let in_scale = inputs[0].out_scales()[0]; + let scale_diff = std::cmp::max(scales.input, scales.params) - inputs[0].out_scales()[0]; + let additional_scale = if scale_diff > 0 { + scale_to_multiplier(scale_diff) + } else { + 1.0 + }; + + SupportedOp::Nonlinear(LookupOp::Recip { + scale: (scale_to_multiplier(in_scale).powf(2.0) * additional_scale).into(), + }) + } + + "LeakyRelu" => { + // Extract the slope layer hyperparams + let leaky_op = load_op::(node.op(), idx, node.op().name().to_string())?; + + let leaky_op: &LeakyRelu = match leaky_op.0.downcast_ref::() { + Some(b) => b, + None => { + return Err(Box::new(GraphError::OpMismatch( + idx, 
+ "leaky relu".to_string(), + ))); + } + }; + + SupportedOp::Nonlinear(LookupOp::LeakyReLU { + slope: crate::circuit::utils::F32(leaky_op.alpha), + }) + } + "Scan" => { + return Err("scan should never be analyzed explicitly".into()); + } + "QuantizeLinearU8" | "DequantizeLinearF32" => SupportedOp::Linear(PolyOp::Identity), + "Abs" => SupportedOp::Nonlinear(LookupOp::Abs), + "Neg" => SupportedOp::Linear(PolyOp::Neg), + "Sigmoid" => SupportedOp::Nonlinear(LookupOp::Sigmoid { + scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + }), + "Sqrt" => SupportedOp::Nonlinear(LookupOp::Sqrt { + scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + }), + "Rsqrt" => SupportedOp::Nonlinear(LookupOp::Rsqrt { + scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + }), + "Exp" => SupportedOp::Nonlinear(LookupOp::Exp { + scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + }), + "Ln" => SupportedOp::Nonlinear(LookupOp::Ln { + scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + }), + "Sin" => SupportedOp::Nonlinear(LookupOp::Sin { + scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + }), + "Cos" => SupportedOp::Nonlinear(LookupOp::Cos { + scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + }), + "Tan" => SupportedOp::Nonlinear(LookupOp::Tan { + scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + }), + "Asin" => SupportedOp::Nonlinear(LookupOp::ASin { + scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + }), + "Acos" => SupportedOp::Nonlinear(LookupOp::ACos { + scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + }), + "Atan" => SupportedOp::Nonlinear(LookupOp::ATan { + scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + }), + "Sinh" => SupportedOp::Nonlinear(LookupOp::Sinh { + scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + }), + "Cosh" => SupportedOp::Nonlinear(LookupOp::Cosh { + scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + 
}), + "Tanh" => SupportedOp::Nonlinear(LookupOp::Tanh { + scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + }), + "Asinh" => SupportedOp::Nonlinear(LookupOp::ASinh { + scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + }), + "Acosh" => SupportedOp::Nonlinear(LookupOp::ACosh { + scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + }), + "Atanh" => SupportedOp::Nonlinear(LookupOp::ATanh { + scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + }), + "Erf" => SupportedOp::Nonlinear(LookupOp::Erf { + scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + }), + "Source" => { + let (scale, datum_type) = match node.outputs[0].fact.datum_type { + DatumType::Bool => (0, InputType::Bool), + DatumType::TDim => (0, InputType::TDim), + DatumType::I64 + | DatumType::I32 + | DatumType::I16 + | DatumType::I8 + | DatumType::U8 + | DatumType::U16 + | DatumType::U32 + | DatumType::U64 => (0, InputType::Int), + DatumType::F16 => (scales.input, InputType::F16), + DatumType::F32 => (scales.input, InputType::F32), + DatumType::F64 => (scales.input, InputType::F64), + _ => return Err(Box::new(GraphError::UnsupportedDataType)), + }; + SupportedOp::Input(crate::circuit::ops::Input { scale, datum_type }) + } + "Cast" => { + let op = load_op::(node.op(), idx, node.op().name().to_string())?; + let dt = op.to; + let input_scales = inputs + .iter() + .flat_map(|x| x.out_scales()) + .collect::>(); + assert_eq!(input_scales.len(), 1); + + let mut constant = inputs[0].opkind(); + let constant = constant.get_mutable_constant(); + + let replace_const = |scale: crate::Scale, + default_op: SupportedOp| + -> Result> { + if let Some(c) = constant { + inputs[0].bump_scale(scale); + c.rebase_scale(scale)?; + inputs[0].replace_opkind(SupportedOp::Constant(c.clone())); + Ok(SupportedOp::Linear(PolyOp::Identity)) + } else { + Ok(default_op) + } + }; + + match dt { + DatumType::Bool + | DatumType::TDim + | DatumType::I64 + | DatumType::I32 + | 
DatumType::I16 + | DatumType::I8 + | DatumType::U8 + | DatumType::U16 + | DatumType::U32 + | DatumType::U64 => { + if input_scales[0] != 0 { + replace_const( + 0, + SupportedOp::Nonlinear(LookupOp::Div { + denom: crate::circuit::utils::F32(scale_to_multiplier( + input_scales[0], + ) + as f32), + }), + )? + } else { + SupportedOp::Linear(PolyOp::Identity) + } + } + DatumType::F16 | DatumType::F32 | DatumType::F64 => { + SupportedOp::Linear(PolyOp::Identity) + } + _ => return Err(Box::new(GraphError::UnsupportedDataType)), + } + } + "Add" => SupportedOp::Linear(PolyOp::Add), + "Sub" => SupportedOp::Linear(PolyOp::Sub), + "Mul" => { + let mut op = SupportedOp::Linear(PolyOp::Mult); + + let const_idx = inputs + .iter() + .enumerate() + .filter(|(_, n)| n.is_constant()) + .map(|(i, _)| i) + .collect::>(); + + if !(const_idx.len() <= 1) { + return Err(Box::new(GraphError::InvalidDims(idx, "mul".to_string()))); + } + + if const_idx.len() == 1 { + let const_idx = const_idx[0]; + if let Some(c) = inputs[const_idx].opkind().get_mutable_constant() { + if c.raw_values.len() == 1 && c.raw_values[0] < 1. { + inputs[const_idx].decrement_use(); + deleted_indices.push(const_idx); + op = SupportedOp::Nonlinear(LookupOp::Div { + // we invert the constant for division + denom: crate::circuit::utils::F32(1. 
/ c.raw_values[0]), + }) + } + } + } + op + } + "Iff" => SupportedOp::Linear(PolyOp::Iff), + "Less" => { + if inputs.len() == 2 { + SupportedOp::Hybrid(HybridOp::Less) + } else { + return Err(Box::new(GraphError::InvalidDims(idx, "less".to_string()))); + } + } + "LessEqual" => { + if inputs.len() == 2 { + SupportedOp::Hybrid(HybridOp::LessEqual) + } else { + return Err(Box::new(GraphError::InvalidDims( + idx, + "less equal".to_string(), + ))); + } + } + "Greater" => { + // Extract the slope layer hyperparams + if inputs.len() == 2 { + SupportedOp::Hybrid(HybridOp::Greater) + } else { + return Err(Box::new(GraphError::InvalidDims( + idx, + "greater".to_string(), + ))); + } + } + "GreaterEqual" => { + // Extract the slope layer hyperparams + if inputs.len() == 2 { + SupportedOp::Hybrid(HybridOp::GreaterEqual) + } else { + return Err(Box::new(GraphError::InvalidDims( + idx, + "greater equal".to_string(), + ))); + } + } + "EinSum" => { + // Extract the slope layer hyperparams + let op: &EinSum = match node.op().downcast_ref::() { + Some(b) => b, + None => { + return Err(Box::new(GraphError::OpMismatch(idx, "einsum".to_string()))); + } + }; + + let axes = &op.axes; + SupportedOp::Linear(PolyOp::Einsum { + equation: axes.to_string(), + }) + } + "Softmax" => { + // Extract the slope layer hyperparams + let softmax_op: &Softmax = match node.op().downcast_ref::() { + Some(b) => b, + None => { + return Err(Box::new(GraphError::OpMismatch(idx, "softmax".to_string()))); + } + }; + + SupportedOp::Hybrid(HybridOp::Softmax { + scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + axes: softmax_op.axes.to_vec(), + }) + } + "MaxPool" => { + // Extract the padding and stride layer hyperparams + let op = Box::new(node.op()); + let sumpool_node: &MaxPool = match op.downcast_ref() { + Some(b) => b, + None => { + return Err(Box::new(GraphError::OpMismatch(idx, "Maxpool".to_string()))); + } + }; + + let pool_spec: &PoolSpec = &sumpool_node.pool_spec; + + // only support pytorch 
type formatting for now + if pool_spec.data_format != DataFormat::NCHW { + return Err(Box::new(GraphError::MissingParams( + "data in wrong format".to_string(), + ))); + } + + let stride = pool_spec + .strides + .clone() + .ok_or(GraphError::MissingParams("stride".to_string()))?; + let padding = match &pool_spec.padding { + PaddingSpec::Explicit(b, a) => [(b[0], b[1]), (a[0], a[1])], + PaddingSpec::ExplicitOnnxPool(b, a, _) => [(b[0], b[1]), (a[0], a[1])], + _ => { + return Err(Box::new(GraphError::MissingParams("padding".to_string()))); + } + }; + let kernel_shape = &pool_spec.kernel_shape; + + let (stride_h, stride_w) = (stride[0], stride[1]); + let (kernel_height, kernel_width) = (kernel_shape[0], kernel_shape[1]); + + SupportedOp::Hybrid(HybridOp::MaxPool2d { + padding, + stride: (stride_h, stride_w), + pool_dims: (kernel_height, kernel_width), + }) + } + "Ceil" => SupportedOp::Nonlinear(LookupOp::Ceil { + scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + }), + "Floor" => SupportedOp::Nonlinear(LookupOp::Floor { + scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + }), + "Round" => SupportedOp::Nonlinear(LookupOp::Round { + scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + }), + "RoundHalfToEven" => SupportedOp::Nonlinear(LookupOp::RoundHalfToEven { + scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + }), + "Sign" => SupportedOp::Nonlinear(LookupOp::Sign), + "Pow" => { + // Extract the slope layer hyperparams from a const + + // if param_visibility.is_public() { + if let Some(c) = inputs[1].opkind().get_mutable_constant() { + inputs[1].decrement_use(); + deleted_indices.push(inputs.len() - 1); + if c.raw_values.len() > 1 { + unimplemented!("only support scalar pow") + } + SupportedOp::Nonlinear(LookupOp::Pow { + scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + a: crate::circuit::utils::F32(c.raw_values[0]), + }) + } else { + unimplemented!("only support constant pow for now") + } + } + "Cube" => 
SupportedOp::Linear(PolyOp::Pow(3)), + "Square" => SupportedOp::Linear(PolyOp::Pow(2)), + "ConvUnary" => { + let conv_node: &ConvUnary = match node.op().downcast_ref::() { + Some(b) => b, + None => { + return Err(Box::new(GraphError::OpMismatch(idx, "conv".to_string()))); + } + }; + + if let Some(dilations) = &conv_node.pool_spec.dilations { + if dilations.iter().any(|x| *x != 1) { + return Err(Box::new(GraphError::MisformedParams( + "non unit dilations not supported".to_string(), + ))); + } + } + + if ((conv_node.pool_spec.data_format != DataFormat::NCHW) + && (conv_node.pool_spec.data_format != DataFormat::CHW)) + || (conv_node.kernel_fmt != KernelFormat::OIHW) + { + return Err(Box::new(GraphError::MisformedParams( + "data or kernel in wrong format".to_string(), + ))); + } + + let stride = match conv_node.pool_spec.strides.clone() { + Some(s) => { + if s.len() == 1 { + (s[0], s[0]) + } else if s.len() == 2 { + (s[0], s[1]) + } else { + return Err(Box::new(GraphError::MissingParams("strides".to_string()))); + } + } + None => { + return Err(Box::new(GraphError::MissingParams("strides".to_string()))); + } + }; + + let padding = match &conv_node.pool_spec.padding { + PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => { + if b.len() == 2 && a.len() == 2 { + [(b[0], b[1]), (a[0], a[1])] + } else if b.len() == 1 && a.len() == 1 { + [(b[0], b[0]), (a[0], a[0])] + } else if b.len() == 1 && a.len() == 2 { + [(b[0], b[0]), (a[0], a[1])] + } else if b.len() == 2 && a.len() == 1 { + [(b[0], b[1]), (a[0], a[0])] + } else { + return Err(Box::new(GraphError::MissingParams("padding".to_string()))); + } + } + _ => { + return Err(Box::new(GraphError::MissingParams("padding".to_string()))); + } + }; + + let kernel = extract_tensor_value(conv_node.kernel.clone(), symbol_values)?; + let kernel = quantize_tensor(kernel, scales.params, param_visibility)?; + + let bias = match conv_node.bias.clone() { + Some(b) => { + let const_value = extract_tensor_value(b, 
symbol_values)?; + + let val = quantize_tensor( + const_value, + scales.params + inputs[0].out_scales()[0], + param_visibility, + )?; + Some(val) + } + None => None, + }; + + SupportedOp::Linear(PolyOp::Conv { + kernel, + bias, + padding, + stride, + }) + } + "Not" => SupportedOp::Linear(PolyOp::Not), + "And" => SupportedOp::Linear(PolyOp::And), + "Or" => SupportedOp::Linear(PolyOp::Or), + "Xor" => SupportedOp::Linear(PolyOp::Xor), + "Equals" => SupportedOp::Hybrid(HybridOp::Equals), + "DeconvUnary" => { + let deconv_node: &DeconvUnary = match node.op().downcast_ref::() { + Some(b) => b, + None => { + return Err(Box::new(GraphError::OpMismatch(idx, "deconv".to_string()))); + } + }; + + if let Some(dilations) = &deconv_node.pool_spec.dilations { + if dilations.iter().any(|x| *x != 1) { + return Err(Box::new(GraphError::MisformedParams( + "non unit dilations not supported".to_string(), + ))); + } + } + + if (deconv_node.pool_spec.data_format != DataFormat::NCHW) + || (deconv_node.kernel_format != KernelFormat::OIHW) + { + return Err(Box::new(GraphError::MisformedParams( + "data or kernel in wrong format".to_string(), + ))); + } + + let stride = match deconv_node.pool_spec.strides.clone() { + Some(s) => (s[0], s[1]), + None => { + return Err(Box::new(GraphError::MissingParams("strides".to_string()))); + } + }; + let padding = match &deconv_node.pool_spec.padding { + PaddingSpec::Explicit(b, a) => [(b[0], b[1]), (a[0], a[1])], + PaddingSpec::ExplicitOnnxPool(b, a, _) => [(b[0], b[1]), (a[0], a[1])], + _ => { + return Err(Box::new(GraphError::MissingParams("padding".to_string()))); + } + }; + + let kernel = extract_tensor_value(deconv_node.kernel.clone(), symbol_values)?; + let kernel = quantize_tensor(kernel, scales.params, param_visibility)?; + + let bias = match deconv_node.bias.clone() { + Some(b) => { + let const_value = extract_tensor_value(b, symbol_values)?; + + let val = quantize_tensor( + const_value, + scales.params + inputs[0].out_scales()[0], + 
param_visibility, + )?; + Some(val) + } + None => None, + }; + + let output_padding: (usize, usize) = + (deconv_node.adjustments[0], deconv_node.adjustments[1]); + + SupportedOp::Linear(PolyOp::DeConv { + kernel, + bias, + padding, + output_padding, + stride, + }) + } + "Downsample" => { + let downsample_node: Downsample = match node.op().downcast_ref::() { + Some(b) => b.clone(), + None => { + return Err(Box::new(GraphError::OpMismatch( + idx, + "downsample".to_string(), + ))); + } + }; + + SupportedOp::Linear(PolyOp::Downsample { + axis: downsample_node.axis, + stride: downsample_node.stride as usize, + modulo: downsample_node.modulo, + }) + } + + "Resize" => { + // this is a bit hacky, but we need to extract the resize node somehow + // and this is the only way I can think of + // see https://github.com/sonos/tract/issues/324 + + let resize_node = format!("{:?}", node); + + if !resize_node.contains("interpolator: Nearest") + && !resize_node.contains("nearest: Floor") + { + unimplemented!("Only nearest neighbor interpolation is supported") + } + // check if optional scale factor is present + if inputs.len() != 2 && inputs.len() != 3 { + return Err(Box::new(GraphError::OpMismatch(idx, "Resize".to_string()))); + } + + let scale_factor_node = // find optional_scales_input in the string and extract the value inside the Some + if resize_node.contains("optional_scales_input: None") { + None + } else { + Some(resize_node + .split("optional_scales_input: ") + .collect::>()[1] + .split("Some(") + .collect::>()[1] + .split(')') + .collect::>()[0] + .parse::()?) 
+ }; + + let scale_factor = if let Some(scale_factor_node) = scale_factor_node { + let boxed_op = inputs[scale_factor_node].opkind(); + if let Some(c) = extract_const_raw_values(boxed_op) { + c.map(|x| x as usize).into_iter().collect::>() + } else { + return Err(Box::new(GraphError::OpMismatch(idx, "Resize".to_string()))); + } + } else { + // default + vec![1] + }; + + for i in 1..inputs.len() { + // remove the resize node from the inputs + if let Some(node) = inputs.get_mut(i) { + node.decrement_use(); + deleted_indices.push(i); + } + } + + SupportedOp::Linear(PolyOp::Resize { scale_factor }) + } + + "SumPool" => { + // Extract the padding and stride layer hyperparams + let op = Box::new(node.op()); + let sumpool_node: &SumPool = match op.downcast_ref() { + Some(b) => b, + None => { + return Err(Box::new(GraphError::OpMismatch(idx, "sumpool".to_string()))); + } + }; + + let pool_spec: &PoolSpec = &sumpool_node.pool_spec; + + // only support pytorch type formatting for now + if pool_spec.data_format != DataFormat::NCHW { + return Err(Box::new(GraphError::MissingParams( + "data in wrong format".to_string(), + ))); + } + + let stride = pool_spec + .strides + .clone() + .ok_or(GraphError::MissingParams("stride".to_string()))?; + let padding = match &pool_spec.padding { + PaddingSpec::Explicit(b, a) => [(b[0], b[1]), (a[0], a[1])], + PaddingSpec::ExplicitOnnxPool(b, a, _) => [(b[0], b[1]), (a[0], a[1])], + _ => { + return Err(Box::new(GraphError::MissingParams("padding".to_string()))); + } + }; + let kernel_shape = &pool_spec.kernel_shape; + + let (stride_h, stride_w) = (stride[0], stride[1]); + let (kernel_height, kernel_width) = (kernel_shape[0], kernel_shape[1]); + + SupportedOp::Linear(PolyOp::SumPool { + padding, + stride: (stride_h, stride_w), + kernel_shape: (kernel_height, kernel_width), + }) + } + "GlobalAvgPool" => SupportedOp::Linear(PolyOp::SumPool { + padding: [(0, 0); 2], + stride: (1, 1), + kernel_shape: (inputs[0].out_dims()[0][1], 
inputs[0].out_dims()[0][2]), + }), + "Pad" => { + let pad_node: &Pad = match node.op().downcast_ref::() { + Some(b) => b, + None => { + return Err(Box::new(GraphError::OpMismatch(idx, "pad".to_string()))); + } + }; + // we only support constant 0 padding + if pad_node.mode + != PadMode::Constant(tract_onnx::prelude::Arc::new( + tract_onnx::prelude::Tensor::zero::(&[])?, + )) + { + return Err(Box::new(GraphError::MisformedParams( + "pad mode or pad type".to_string(), + ))); + } + + let padding_len = pad_node.pads.len(); + + // we only support symmetrical padding that affects the last 2 dims (height and width params) + for (i, pad_params) in pad_node.pads.iter().enumerate() { + if (i < padding_len - 2) && ((pad_params.0 != 0) || (pad_params.1 != 0)) { + return Err(Box::new(GraphError::MisformedParams( + "ezkl currently only supports padding height and width dimensions" + .to_string(), + ))); + } + } + + let padding = [ + ( + pad_node.pads[padding_len - 2].0, + pad_node.pads[padding_len - 1].0, + ), + ( + pad_node.pads[padding_len - 2].1, + pad_node.pads[padding_len - 1].1, + ), + ]; + SupportedOp::Linear(PolyOp::Pad(padding)) + } + "RmAxis" | "Reshape" | "AddAxis" => { + // Extract the slope layer hyperparams + let shapes = node_output_shapes(&node)?; + let mut output_shape = shapes[0] + .as_ref() + .ok_or(GraphError::InvalidDims(idx, "reshape".to_string()))? + .clone(); + if output_shape.is_empty() { + output_shape = vec![1]; + } + + SupportedOp::Linear(PolyOp::Reshape(output_shape)) + } + "Flatten" => { + let new_dims: Vec = vec![inputs[0].out_dims()[0].iter().product::()]; + SupportedOp::Linear(PolyOp::Flatten(new_dims)) + } + c => { + warn!("Unknown op: {}", c); + SupportedOp::Unknown(crate::circuit::ops::Unknown) + } + }; + + Ok((node, deleted_indices)) +} + +/// Extracts the raw values from a [crate::circuit::ops::Constant] op. 
+pub fn extract_const_raw_values(op: SupportedOp) -> Option> { + match op { + SupportedOp::Constant(crate::circuit::ops::Constant { raw_values, .. }) => Some(raw_values), + _ => None, + } +} + +/// Extracts the quantized values from a [crate::circuit::ops::Constant] op. +pub fn extract_const_quantized_values(op: SupportedOp) -> Option> { + match op { + SupportedOp::Constant(crate::circuit::ops::Constant { + quantized_values, .. + }) => Some(quantized_values), + _ => None, + } +} + +/// Extract the quantized values from a conv op +pub fn extract_conv_values(boxed_op: Box>) -> [Option>; 2] { + let op = boxed_op + .as_any() + .downcast_ref::>(); + + if let Some(PolyOp::Conv { kernel, bias, .. }) = op { + return [Some(kernel.clone()), bias.clone()]; + } + [None, None] +} + +/// Converts a tensor to a [ValTensor] with a given scale. +pub fn quantize_tensor( + const_value: Tensor, + scale: crate::Scale, + visibility: &Visibility, +) -> Result, Box> { + let mut value: Tensor = const_value.par_enum_map(|_, x| { + Ok::<_, TensorError>(crate::fieldutils::i128_to_felt::(quantize_float( + &(x).into(), + 0.0, + scale, + )?)) + })?; + + value.set_scale(scale); + value.set_visibility(visibility); + Ok(value) +} + +use crate::tensor::ValTensor; +/// Split a [ValTensor] into a vector of [ValTensor]s. 
+pub(crate) fn split_valtensor( + values: &ValTensor, + shapes: Vec>, +) -> Result>, Box> { + let mut tensors: Vec> = Vec::new(); + let mut start = 0; + for shape in shapes { + let end = start + shape.iter().product::(); + let mut tensor = values.get_slice(&[start..end])?; + tensor.reshape(&shape)?; + tensors.push(tensor); + start = end; + } + Ok(tensors) +} + +/// +pub fn homogenize_input_scales( + op: Box>, + input_scales: Vec, + inputs_to_scale: Vec, +) -> Result>, Box> { + let relevant_input_scales = input_scales + .clone() + .into_iter() + .enumerate() + .filter(|(idx, _)| inputs_to_scale.contains(idx)) + .map(|(_, scale)| scale) + .collect_vec(); + + if inputs_to_scale.is_empty() { + return Ok(op); + } + // else if all inputs_scales at inputs_to_scale are the same, we don't need to do anything + if relevant_input_scales.windows(2).all(|w| w[0] == w[1]) { + return Ok(op); + } + + let mut multipliers: Vec = vec![1; input_scales.len()]; + + let max_scale = input_scales.iter().max().ok_or("no max scale")?; + let _ = input_scales + .iter() + .enumerate() + .map(|(idx, input_scale)| { + if !inputs_to_scale.contains(&idx) { + return; + } + let scale_diff = max_scale - input_scale; + if scale_diff > 0 { + let mult = crate::graph::scale_to_multiplier(scale_diff); + multipliers[idx] = mult as u128; + } + }) + .collect_vec(); + + // only rescale if need to + if multipliers.iter().any(|&x| x > 1) { + Ok(Box::new(Rescaled { + inner: Box::new(op.into()), + scale: (0..input_scales.len()).zip(multipliers).collect_vec(), + })) + } else { + Ok(op) + } +} + +#[cfg(test)] +pub mod tests { + + use super::*; + + #[test] + fn test_flatten_valtensors() { + let tensor1: Tensor = (0..10).map(|x| x.into()).into(); + let tensor2: Tensor = (10..20).map(|x| x.into()).into(); + let tensor3: Tensor = (20..30).map(|x| x.into()).into(); + + let mut tensor = Tensor::new(Some(&[tensor1, tensor2, tensor3]), &[3]) + .unwrap() + .combine() + .unwrap(); + + 
tensor.set_visibility(&Visibility::Public); + + let flattened: ValTensor = tensor.try_into().unwrap(); + + assert_eq!(flattened.len(), 30); + + let split = split_valtensor(&flattened, vec![vec![2, 5], vec![10], vec![5, 2]]).unwrap(); + + assert_eq!(split.len(), 3); + assert_eq!(split[0].len(), 10); + assert_eq!(split[0].dims(), vec![2, 5]); + assert_eq!(split[1].len(), 10); + assert_eq!(split[1].dims(), vec![10]); + assert_eq!(split[2].dims(), vec![5, 2]); + assert_eq!(split[2].len(), 10); + } +} diff --git a/mnist_ezkl/src/graph/vars.rs b/mnist_ezkl/src/graph/vars.rs new file mode 100644 index 0000000..b73a39a --- /dev/null +++ b/mnist_ezkl/src/graph/vars.rs @@ -0,0 +1,451 @@ +use std::error::Error; + +use crate::tensor::TensorType; +use crate::tensor::{ValTensor, VarTensor}; +use crate::RunArgs; +use halo2_proofs::plonk::{Column, ConstraintSystem, Instance}; +use halo2curves::ff::PrimeField; +use itertools::Itertools; +use log::debug; +#[cfg(feature = "python-bindings")] +use pyo3::{ + exceptions::PyValueError, types::PyString, FromPyObject, IntoPy, PyAny, PyObject, PyResult, + PyTryFrom, Python, ToPyObject, +}; + +use serde::{Deserialize, Serialize}; + +use super::*; + +/// Label enum to track whether model input, model parameters, and model output are public, private, or hashed +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Default)] +pub enum Visibility { + /// Mark an item as private to the prover (not in the proof submitted for verification) + #[default] + Private, + /// Mark an item as public (sent in the proof submitted for verification) + Public, + /// Mark an item as publicly committed to (hash sent in the proof submitted for verification) + Hashed { + /// Whether the hash is used as an instance (sent in the proof submitted for verification) + /// if false the hash is used as an advice (not in the proof submitted for verification) and is then sent to the computational graph + /// if true the hash is used as an instance 
(sent in the proof submitted for verification) the *inputs* to the hashing function are then sent to the computational graph + hash_is_public: bool, + /// + outlets: Vec, + }, + /// Mark an item as publicly committed to (KZG commitment sent in the proof submitted for verification) + KZGCommit, + /// Mark an item as encrypted (public key and encrypted message sent in the proof submitted for verificatio) + Encrypted, + /// assigned as a constant in the circuit + Fixed, +} + +impl<'a> From<&'a str> for Visibility { + fn from(s: &'a str) -> Self { + if s.contains("hashed/private") { + // split on last occurence of '/' + let (_, outlets) = s.split_at(s.rfind('/').unwrap()); + let outlets = outlets + .trim_start_matches('/') + .split(',') + .map(|s| s.parse::().unwrap()) + .collect_vec(); + + return Visibility::Hashed { + hash_is_public: false, + outlets, + }; + } + match s { + "private" => Visibility::Private, + "public" => Visibility::Public, + "kzgcommit" => Visibility::KZGCommit, + "fixed" => Visibility::Fixed, + "hashed" | "hashed/public" => Visibility::Hashed { + hash_is_public: true, + outlets: vec![], + }, + "encrypted" => Visibility::Encrypted, + _ => { + log::error!("Invalid value for Visibility: {}", s); + log::warn!("Defaulting to private"); + Visibility::Private + } + } + } +} + +#[cfg(feature = "python-bindings")] +/// Converts Visibility into a PyObject (Required for Visibility to be compatible with Python) +impl IntoPy for Visibility { + fn into_py(self, py: Python) -> PyObject { + match self { + Visibility::Private => "private".to_object(py), + Visibility::Public => "public".to_object(py), + Visibility::Fixed => "fixed".to_object(py), + Visibility::KZGCommit => "kzgcommit".to_object(py), + Visibility::Hashed { + hash_is_public, + outlets, + } => { + if hash_is_public { + "hashed/public".to_object(py) + } else { + let outlets = outlets + .iter() + .map(|o| o.to_string()) + .collect_vec() + .join(","); + format!("hashed/private/{}", outlets).to_object(py) 
+ } + } + Visibility::Encrypted => "encrypted".to_object(py), + } + } +} + +#[cfg(feature = "python-bindings")] +/// Obtains Visibility from PyObject (Required for Visibility to be compatible with Python) +impl<'source> FromPyObject<'source> for Visibility { + fn extract(ob: &'source PyAny) -> PyResult { + let trystr = ::try_from(ob)?; + let strval = trystr.to_string(); + + let strval = strval.as_str(); + + if strval.contains("hashed/private") { + // split on last occurence of '/' + let (_, outlets) = strval.split_at(strval.rfind('/').unwrap()); + let outlets = outlets + .trim_start_matches('/') + .split(',') + .map(|s| s.parse::().unwrap()) + .collect_vec(); + + return Ok(Visibility::Hashed { + hash_is_public: false, + outlets, + }); + } + + match strval.to_lowercase().as_str() { + "private" => Ok(Visibility::Private), + "public" => Ok(Visibility::Public), + "kzgcommit" => Ok(Visibility::KZGCommit), + "hashed" => Ok(Visibility::Hashed { + hash_is_public: true, + outlets: vec![], + }), + "hashed/public" => Ok(Visibility::Hashed { + hash_is_public: true, + outlets: vec![], + }), + "fixed" => Ok(Visibility::Fixed), + "encrypted" => Ok(Visibility::Encrypted), + _ => Err(PyValueError::new_err("Invalid value for Visibility")), + } + } +} + +impl Visibility { + #[allow(missing_docs)] + pub fn is_fixed(&self) -> bool { + matches!(&self, Visibility::Fixed) + } + #[allow(missing_docs)] + pub fn is_private(&self) -> bool { + matches!(&self, Visibility::Private) || self.is_hashed_private() + } + + #[allow(missing_docs)] + pub fn is_public(&self) -> bool { + matches!(&self, Visibility::Public) + } + #[allow(missing_docs)] + pub fn is_hashed(&self) -> bool { + matches!(&self, Visibility::Hashed { .. }) + } + #[allow(missing_docs)] + pub fn is_kzgcommit(&self) -> bool { + matches!(&self, Visibility::KZGCommit) + } + + #[allow(missing_docs)] + pub fn is_hashed_public(&self) -> bool { + if let Visibility::Hashed { + hash_is_public: true, + .. 
+ } = self + { + return true; + } + false + } + #[allow(missing_docs)] + pub fn is_hashed_private(&self) -> bool { + if let Visibility::Hashed { + hash_is_public: false, + .. + } = self + { + return true; + } + false + } + #[allow(missing_docs)] + pub fn is_encrypted(&self) -> bool { + matches!(&self, Visibility::Encrypted) + } + #[allow(missing_docs)] + pub fn requires_processing(&self) -> bool { + matches!(&self, Visibility::Encrypted) + | matches!(&self, Visibility::Hashed { .. }) + | matches!(&self, Visibility::KZGCommit) + } + #[allow(missing_docs)] + pub fn overwrites_inputs(&self) -> Vec { + if let Visibility::Hashed { outlets, .. } = self { + return outlets.clone(); + } + vec![] + } +} +impl std::fmt::Display for Visibility { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + Visibility::KZGCommit => write!(f, "kzgcommit"), + Visibility::Private => write!(f, "private"), + Visibility::Public => write!(f, "public"), + Visibility::Fixed => write!(f, "fixed"), + Visibility::Hashed { .. } => write!(f, "hashed"), + Visibility::Encrypted => write!(f, "encrypted"), + } + } +} + +/// Represents the scale of the model input, model parameters. +#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq, PartialOrd)] +pub struct VarScales { + /// + pub input: crate::Scale, + /// + pub params: crate::Scale, + /// + pub rebase_multiplier: u32, +} + +impl std::fmt::Display for VarScales { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "(inputs: {}, params: {})", self.input, self.params) + } +} + +impl VarScales { + /// + pub fn get_max(&self) -> crate::Scale { + std::cmp::max(self.input, self.params) + } + + /// Place in [VarScales] struct. 
+ pub fn from_args(args: &RunArgs) -> Result> { + Ok(Self { + input: args.input_scale, + params: args.param_scale, + rebase_multiplier: args.scale_rebase_multiplier, + }) + } +} + +/// Represents whether the model input, model parameters, and model output are Public or Private to the prover. +#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, PartialOrd)] +pub struct VarVisibility { + /// Input to the model or computational graph + pub input: Visibility, + /// Parameters, such as weights and biases, in the model + pub params: Visibility, + /// Output of the model or computational graph + pub output: Visibility, +} +impl std::fmt::Display for VarVisibility { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!( + f, + "(inputs: {}, params: {}, outputs: {})", + self.input, self.params, self.output + ) + } +} + +impl Default for VarVisibility { + fn default() -> Self { + Self { + input: Visibility::Private, + params: Visibility::Private, + output: Visibility::Public, + } + } +} + +impl VarVisibility { + /// Read from cli args whether the model input, model parameters, and model output are Public or Private to the prover. + /// Place in [VarVisibility] struct. 
+ pub fn from_args(args: &RunArgs) -> Result> { + let input_vis = &args.input_visibility; + let params_vis = &args.param_visibility; + let output_vis = &args.output_visibility; + + if params_vis.is_public() { + return Err( + "public visibility for params is deprecated, please use `fixed` instead".into(), + ); + } + + if !output_vis.is_public() + & !params_vis.is_public() + & !input_vis.is_public() + & !output_vis.is_fixed() + & !params_vis.is_fixed() + & !input_vis.is_fixed() + & !output_vis.is_hashed() + & !params_vis.is_hashed() + & !input_vis.is_hashed() + & !output_vis.is_encrypted() + & !params_vis.is_encrypted() + & !input_vis.is_encrypted() + & !output_vis.is_kzgcommit() + & !params_vis.is_kzgcommit() + & !input_vis.is_kzgcommit() + { + return Err(Box::new(GraphError::Visibility)); + } + Ok(Self { + input: input_vis.clone(), + params: params_vis.clone(), + output: output_vis.clone(), + }) + } +} + +/// A wrapper for holding all columns that will be assigned to by a model. +#[derive(Clone, Debug)] +pub struct ModelVars { + #[allow(missing_docs)] + pub advices: Vec, + #[allow(missing_docs)] + pub instance: Option>, +} + +impl ModelVars { + /// Get instance col + pub fn get_instance_col(&self) -> Option<&Column> { + if let Some(instance) = &self.instance { + match instance { + ValTensor::Instance { inner, .. 
} => Some(inner), + _ => None, + } + } else { + None + } + } + + /// Set the initial instance offset + pub fn set_initial_instance_offset(&mut self, offset: usize) { + if let Some(instance) = &mut self.instance { + instance.set_initial_instance_offset(offset); + } + } + + /// Get the total instance len + pub fn get_instance_len(&self) -> usize { + if let Some(instance) = &self.instance { + instance.get_total_instance_len() + } else { + 0 + } + } + + /// Increment the instance offset + pub fn increment_instance_idx(&mut self) { + if let Some(instance) = &mut self.instance { + instance.increment_idx(); + } + } + + /// Reset the instance offset + pub fn set_instance_idx(&mut self, val: usize) { + if let Some(instance) = &mut self.instance { + instance.set_idx(val); + } + } + + /// Get the instance offset + pub fn get_instance_idx(&self) -> usize { + if let Some(instance) = &self.instance { + instance.get_idx() + } else { + 0 + } + } + + /// + pub fn instantiate_instance( + &mut self, + cs: &mut ConstraintSystem, + instance_dims: Vec>, + instance_scale: crate::Scale, + existing_instance: Option>, + ) { + debug!("model uses {:?} instance dims", instance_dims); + self.instance = if let Some(existing_instance) = existing_instance { + debug!("using existing instance"); + Some(ValTensor::new_instance_from_col( + instance_dims, + instance_scale, + existing_instance, + )) + } else { + Some(ValTensor::new_instance(cs, instance_dims, instance_scale)) + }; + } + + /// Allocate all columns that will be assigned to by a model. 
+ pub fn new( + cs: &mut ConstraintSystem, + logrows: usize, + var_len: usize, + num_inner_cols: usize, + num_constants: usize, + module_requires_fixed: bool, + ) -> Self { + info!("number of blinding factors: {}", cs.blinding_factors()); + + let advices = (0..3) + .map(|_| VarTensor::new_advice(cs, logrows, num_inner_cols, var_len)) + .collect_vec(); + + debug!( + "model uses {} advice blocks (size={})", + advices.iter().map(|v| v.num_blocks()).sum::(), + num_inner_cols + ); + + let num_const_cols = + VarTensor::constant_cols(cs, logrows, num_constants, module_requires_fixed); + debug!("model uses {} fixed columns", num_const_cols); + + ModelVars { + advices, + instance: None, + } + } + + /// Allocate all columns that will be assigned to by a model. + pub fn new_dummy() -> Self { + ModelVars { + advices: vec![], + instance: None, + } + } +} diff --git a/mnist_ezkl/src/hub.rs b/mnist_ezkl/src/hub.rs new file mode 100644 index 0000000..f6fe892 --- /dev/null +++ b/mnist_ezkl/src/hub.rs @@ -0,0 +1,146 @@ +use serde::{Deserialize, Serialize}; + +/// Stores users organizations +#[derive(Serialize, Deserialize, Debug)] +pub struct Organization { + /// The organization id + pub id: String, + /// The users username + pub name: String, +} + +impl Organization { + /// Export the organization as json + pub fn as_json(&self) -> Result> { + let serialized = match serde_json::to_string(&self) { + Ok(s) => s, + Err(e) => { + return Err(Box::new(e)); + } + }; + Ok(serialized) + } +} + +/// Stores Organization +#[derive(Serialize, Deserialize, Debug)] +pub struct Organizations { + /// An Array of Organizations + pub organizations: Vec, +} + +impl Organizations { + /// Export the organizations as json + pub fn as_json(&self) -> Result> { + let serialized = match serde_json::to_string(&self) { + Ok(s) => s, + Err(e) => { + return Err(Box::new(e)); + } + }; + Ok(serialized) + } +} + +/// Stores the Proof Response +#[derive(Debug, Deserialize, Serialize)] +pub struct Proof { + /// 
stores the artifact + pub artifact: Option, + /// stores the Proof Id + pub id: String, + /// stores the instances + pub instances: Option>, + /// stores the proofs + pub proof: Option, + /// stores the status + pub status: Option, + /// stores the transcript type + #[serde(rename = "transcriptType")] + pub transcript_type: Option, +} + +impl Proof { + /// Export the proof as json + pub fn as_json(&self) -> Result> { + let serialized = match serde_json::to_string(&self) { + Ok(s) => s, + Err(e) => { + return Err(Box::new(e)); + } + }; + Ok(serialized) + } +} + +/// Stores the Artifacts +#[derive(Debug, Deserialize, Serialize)] +pub struct Artifact { + /// stores the aritfact id + pub id: Option, + /// stores the name of the artifact + pub name: Option, + /// stores the status of the artifact + pub status: Option, + /// stores errors while producing the artifact + pub errors: Option, +} + +impl Artifact { + /// Export the artifact as json + pub fn as_json(&self) -> Result> { + let serialized = match serde_json::to_string(&self) { + Ok(s) => s, + Err(e) => { + return Err(Box::new(e)); + } + }; + Ok(serialized) + } +} + +#[cfg(feature = "python-bindings")] +impl pyo3::ToPyObject for Artifact { + fn to_object(&self, py: pyo3::Python) -> pyo3::PyObject { + let dict = pyo3::types::PyDict::new(py); + dict.set_item("id", &self.id).unwrap(); + dict.set_item("name", &self.name).unwrap(); + dict.set_item("status", &self.status).unwrap(); + dict.set_item("errors", &self.errors).unwrap(); + dict.into() + } +} + +#[cfg(feature = "python-bindings")] +impl pyo3::ToPyObject for Proof { + fn to_object(&self, py: pyo3::Python) -> pyo3::PyObject { + let dict = pyo3::types::PyDict::new(py); + dict.set_item("artifact", &self.artifact).unwrap(); + dict.set_item("id", &self.id).unwrap(); + dict.set_item("instances", &self.instances).unwrap(); + dict.set_item("proof", &self.proof).unwrap(); + dict.set_item("status", &self.status).unwrap(); + dict.set_item("transcript_type", 
&self.transcript_type) + .unwrap(); + dict.into() + } +} + +#[cfg(feature = "python-bindings")] +impl pyo3::ToPyObject for Organizations { + fn to_object(&self, py: pyo3::Python) -> pyo3::PyObject { + let dict = pyo3::types::PyDict::new(py); + dict.set_item("organizations", &self.organizations).unwrap(); + dict.into() + } +} + +#[cfg(feature = "python-bindings")] +impl pyo3::ToPyObject for Organization { + fn to_object(&self, py: pyo3::Python) -> pyo3::PyObject { + let dict = pyo3::types::PyDict::new(py); + dict.set_item("id", &self.id).unwrap(); + dict.set_item("name", &self.name).unwrap(); + dict.into() + } +} diff --git a/mnist_ezkl/src/lib.rs b/mnist_ezkl/src/lib.rs new file mode 100644 index 0000000..8fd0f2e --- /dev/null +++ b/mnist_ezkl/src/lib.rs @@ -0,0 +1,184 @@ +#![deny( + bad_style, + dead_code, + improper_ctypes, + non_shorthand_field_patterns, + no_mangle_generic_items, + overflowing_literals, + path_statements, + patterns_in_fns_without_body, + private_in_public, + unconditional_recursion, + unused, + unused_allocation, + unused_comparisons, + unused_parens, + while_true, + missing_docs, + trivial_casts, + trivial_numeric_casts, + unused_extern_crates, + unused_import_braces, + missing_debug_implementations, + unsafe_code +)] +// we allow this for our dynamic range based indexing scheme +#![allow(clippy::single_range_in_vec_init)] +#![feature(round_ties_even)] + +//! A library for turning computational graphs, such as neural networks, into ZK-circuits. +//! + +use circuit::Tolerance; +use clap::Args; +use graph::Visibility; +use serde::{Deserialize, Serialize}; + +/// Methods for configuring tensor operations and assigning values to them in a Halo2 circuit. +pub mod circuit; +/// CLI commands. 
+#[cfg(not(target_arch = "wasm32"))] +pub mod commands; +#[cfg(not(target_arch = "wasm32"))] +// abigen doesn't generate docs for this module +#[allow(missing_docs)] +/// Utility functions for contracts +pub mod eth; +/// Command execution +/// +#[cfg(not(target_arch = "wasm32"))] +pub mod execute; +/// Utilities for converting from Halo2 Field types to integers (and vice-versa). +pub mod fieldutils; +/// Methods for loading onnx format models and automatically laying them out in +/// a Halo2 circuit. +#[cfg(feature = "onnx")] +pub mod graph; +/// Methods for deploying and interacting with the ezkl hub +#[cfg(not(target_arch = "wasm32"))] +pub mod hub; +/// beautiful logging +#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] +pub mod logger; +/// Tools for proofs and verification used by cli +pub mod pfsys; +/// Python bindings +#[cfg(feature = "python-bindings")] +pub mod python; +/// An implementation of multi-dimensional tensors. +pub mod tensor; +/// wasm prover and verifier +#[cfg(all(target_arch = "wasm32", target_os = "unknown"))] +pub mod wasm; + +/// The denominator in the fixed point representation used when quantizing inputs +pub type Scale = i32; + +/// Parameters specific to a proving run +#[derive(Debug, Args, Deserialize, Serialize, Clone, Default, PartialEq, PartialOrd)] +pub struct RunArgs { + /// The tolerance for error on model outputs + #[arg(short = 'T', long, default_value = "0")] + pub tolerance: Tolerance, + /// The denominator in the fixed point representation used when quantizing inputs + #[arg(short = 'S', long, default_value = "7", allow_hyphen_values = true)] + pub input_scale: Scale, + /// The denominator in the fixed point representation used when quantizing parameters + #[arg(long, default_value = "7", allow_hyphen_values = true)] + pub param_scale: Scale, + /// if the scale is ever > scale_rebase_multiplier * input_scale then the scale is rebased to input_scale (this a more advanced parameter, use with caution) + 
#[arg(long, default_value = "1")] + pub scale_rebase_multiplier: u32, + /// The min and max elements in the lookup table input column + #[arg(short = 'B', long, value_parser = parse_tuple::, default_value = "(-32768,32768)")] + pub lookup_range: (i128, i128), + /// The log_2 number of rows + #[arg(short = 'K', long, default_value = "17")] + pub logrows: u32, + /// The log_2 number of rows + #[arg(short = 'N', long, default_value = "2")] + pub num_inner_cols: usize, + /// Hand-written parser for graph variables, eg. batch_size=1 + #[arg(short = 'V', long, value_parser = parse_key_val::, default_value = "batch_size=1", value_delimiter = ',')] + pub variables: Vec<(String, usize)>, + /// Flags whether inputs are public, private, hashed + #[arg(long, default_value = "private")] + pub input_visibility: Visibility, + /// Flags whether outputs are public, private, hashed + #[arg(long, default_value = "public")] + pub output_visibility: Visibility, + /// Flags whether params are public, private, hashed + #[arg(long, default_value = "private")] + pub param_visibility: Visibility, +} + +impl RunArgs { + /// + pub fn validate(&self) -> Result<(), Box> { + if self.scale_rebase_multiplier < 1 { + return Err("scale_rebase_multiplier must be >= 1".into()); + } + if self.lookup_range.0 > self.lookup_range.1 { + return Err("lookup_range min is greater than max".into()); + } + if self.logrows < 1 { + return Err("logrows must be >= 1".into()); + } + if self.num_inner_cols < 1 { + return Err("num_inner_cols must be >= 1".into()); + } + Ok(()) + } + + /// Export the ezkl configuration as json + pub fn as_json(&self) -> Result> { + let serialized = match serde_json::to_string(&self) { + Ok(s) => s, + Err(e) => { + return Err(Box::new(e)); + } + }; + Ok(serialized) + } + /// Parse an ezkl configuration from a json + pub fn from_json(arg_json: &str) -> Result { + serde_json::from_str(arg_json) + } +} + +/// Parse a single key-value pair +fn parse_key_val( + s: &str, +) -> Result<(T, U), 
Box> +where + T: std::str::FromStr, + T::Err: std::error::Error + Send + Sync + 'static, + U: std::str::FromStr, + U::Err: std::error::Error + Send + Sync + 'static, +{ + let pos = s + .find('=') + .ok_or_else(|| format!("invalid KEY=value: no `=` found in `{s}`"))?; + Ok((s[..pos].parse()?, s[pos + 1..].parse()?)) +} + +/// Parse a tuple +fn parse_tuple(s: &str) -> Result<(T, T), Box> +where + T: std::str::FromStr + Clone, + T::Err: std::error::Error + Send + Sync + 'static, +{ + let res = s.trim_matches(|p| p == '(' || p == ')').split(','); + + let res = res + .map(|x| { + // remove blank space + let x = x.trim(); + x.parse::() + }) + .collect::, _>>()?; + if res.len() != 2 { + return Err("invalid tuple".into()); + } + Ok((res[0].clone(), res[1].clone())) +} diff --git a/mnist_ezkl/src/logger.rs b/mnist_ezkl/src/logger.rs new file mode 100644 index 0000000..73b3888 --- /dev/null +++ b/mnist_ezkl/src/logger.rs @@ -0,0 +1,91 @@ +use colored::*; +use env_logger::Builder; +use instant::Instant; +use log::{Level, LevelFilter, Record}; +use std::env; +use std::fmt::Formatter; +use std::io::Write; + +/// sets the log level color +#[allow(dead_code)] +pub fn level_color(level: &log::Level, msg: &str) -> String { + match level { + Level::Error => msg.red(), + Level::Warn => msg.yellow(), + Level::Info => msg.blue(), + Level::Debug => msg.green(), + Level::Trace => msg.magenta(), + } + .bold() + .to_string() +} + +/// sets the log level text color +pub fn level_text_color(level: &log::Level, msg: &str) -> String { + match level { + Level::Error => msg.red(), + Level::Warn => msg.yellow(), + Level::Info => msg.white(), + Level::Debug => msg.white(), + Level::Trace => msg.white(), + } + .bold() + .to_string() +} + +/// sets the log level token +fn level_token(level: &Level) -> &str { + match *level { + Level::Error => "E", + Level::Warn => "W", + Level::Info => "*", + Level::Debug => "D", + Level::Trace => "T", + } +} + +/// sets the log level prefix token +fn 
prefix_token(level: &Level) -> String { + format!( + "{}{}{}", + "[".blue().bold(), + level_color(level, level_token(level)), + "]".blue().bold() + ) +} + +/// formats the log +pub fn format(buf: &mut Formatter, record: &Record<'_>) -> Result<(), std::fmt::Error> { + let sep = format!("\n{} ", " | ".white().bold()); + let level = record.level(); + writeln!( + buf, + "{} {}", + prefix_token(&level), + level_color(&level, record.args().as_str().unwrap()).replace('\n', &sep), + ) +} + +/// initializes the logger +pub fn init_logger() { + let start = Instant::now(); + let mut builder = Builder::new(); + + builder.format(move |buf, record| { + writeln!( + buf, + "{} [{}s, {}] - {}", + prefix_token(&record.level()), + start.elapsed().as_secs(), + record.metadata().target(), + level_text_color(&record.level(), &format!("{}", record.args())) + .replace('\n', &format!("\n{} ", " | ".white().bold())) + ) + }); + builder.target(env_logger::Target::Stdout); + builder.filter(None, LevelFilter::Info); + if env::var("RUST_LOG").is_ok() { + builder.parse_filters(&env::var("RUST_LOG").unwrap()); + } + builder.init(); +} diff --git a/mnist_ezkl/src/pfsys/evm/aggregation.rs b/mnist_ezkl/src/pfsys/evm/aggregation.rs new file mode 100644 index 0000000..6f9c794 --- /dev/null +++ b/mnist_ezkl/src/pfsys/evm/aggregation.rs @@ -0,0 +1,420 @@ +use crate::pfsys::{Snark, SnarkWitness}; +use halo2_proofs::circuit::AssignedCell; +use halo2_proofs::plonk::{self}; +use halo2_proofs::{ + circuit::{Layouter, SimpleFloorPlanner, Value}, + plonk::{Circuit, ConstraintSystem}, +}; +use halo2_wrong_ecc::{ + integer::rns::Rns, + maingate::{ + MainGate, MainGateConfig, MainGateInstructions, RangeChip, RangeConfig, RangeInstructions, + RegionCtx, + }, + EccConfig, +}; +use halo2curves::bn256::{Bn256, Fq, Fr, G1Affine}; +use halo2curves::ff::PrimeField; +use itertools::Itertools; +use log::trace; +use rand::rngs::OsRng; +use snark_verifier::loader::native::NativeLoader; +use 
snark_verifier::loader::EcPointLoader; +use snark_verifier::{ + loader, + pcs::{ + kzg::{ + Bdfg21, KzgAccumulator, KzgAs, KzgSuccinctVerifyingKey, LimbsEncoding, + LimbsEncodingInstructions, + }, + AccumulationScheme, AccumulationSchemeProver, + }, + system, + util::arithmetic::fe_to_limbs, + verifier::{self, SnarkVerifier}, +}; +use std::rc::Rc; +use thiserror::Error; + +const LIMBS: usize = 4; +const BITS: usize = 68; +type As = KzgAs; +/// Type for aggregator verification +type PlonkSuccinctVerifier = verifier::plonk::PlonkSuccinctVerifier>; + +const T: usize = 5; +const RATE: usize = 4; +const R_F: usize = 8; +const R_P: usize = 60; + +type Svk = KzgSuccinctVerifyingKey; +type BaseFieldEccChip = halo2_wrong_ecc::BaseFieldEccChip; +/// The loader type used in the transcript definition +type Halo2Loader<'a> = loader::halo2::Halo2Loader<'a, G1Affine, BaseFieldEccChip>; +/// Application snark transcript +pub type PoseidonTranscript = + system::halo2::transcript::halo2::PoseidonTranscript; + +#[derive(Error, Debug)] +/// Errors related to proof aggregation +pub enum AggregationError { + /// A KZG proof could not be verified + #[error("failed to verify KZG proof")] + KZGProofVerification, + /// proof read errors + #[error("Failed to read proof")] + ProofRead, + /// proof verification errors + #[error("Failed to verify proof")] + ProofVerify, + /// proof creation errors + #[error("Failed to create proof")] + ProofCreate, +} + +type AggregationResult<'a> = ( + // accumulator + KzgAccumulator>>, + // the set of assigned cells + Vec>>, +); + +type LoadedProof<'a> = verifier::plonk::PlonkProof< + G1Affine, + Rc< + loader::halo2::Halo2Loader< + 'a, + G1Affine, + halo2_wrong_ecc::BaseFieldEccChip, + >, + >, + KzgAs, +>; + +/// Aggregate one or more application snarks of the same shape into a KzgAccumulator +pub fn aggregate<'a>( + svk: &Svk, + loader: &Rc>, + snarks: &[SnarkWitness], + as_proof: Value<&'_ [u8]>, + split_proofs: bool, +) -> Result, plonk::Error> { + let 
assign_instances = |instances: &[Vec>]| { + instances + .iter() + .map(|instances| { + instances + .iter() + .map(|instance| loader.assign_scalar(*instance)) + .collect_vec() + }) + .collect_vec() + }; + + let mut accumulators = vec![]; + let mut snark_instances = vec![]; + let mut proofs: Vec> = vec![]; + + for snark in snarks.iter() { + let protocol = snark.protocol.as_ref().unwrap().loaded(loader); + let instances = assign_instances(&snark.instances); + + // get assigned cells + snark_instances.extend(instances.iter().map(|instance| { + instance + .iter() + .map(|v| v.clone().into_assigned()) + .collect_vec() + })); + + // loader.ctx().constrain_equal(cell_0, cell_1) + let mut transcript = PoseidonTranscript::, _>::new(loader, snark.proof()); + let proof = PlonkSuccinctVerifier::read_proof(svk, &protocol, &instances, &mut transcript) + .map_err(|_| plonk::Error::Synthesis)?; + + if split_proofs { + let previous_proof = proofs.last(); + let split_commit = match snark.clone().split { + Some(split) => split, + None => { + log::error!("Failed to split KZG commit for sequential proofs"); + return Err(plonk::Error::Synthesis); + } + }; + if let Some(previous_proof) = previous_proof { + // output of previous proof + let output = &previous_proof.witnesses[split_commit.start..split_commit.end]; + // input of current proof + let split_commit_len = split_commit.end - split_commit.start; + let input = &proof.witnesses[..split_commit_len]; + // these points were already assigned previously when loading the transcript so this is safe + // and equivalent to a copy constraint and an equality constraint + for (output, input) in output.iter().zip(input.iter()) { + loader + .ec_point_assert_eq("assert commits match", output, input) + .map_err(|e| { + log::error!( + "Failed to match KZG commits for sequential proofs: {:?}", + e + ); + plonk::Error::Synthesis + })?; + } + } + proofs.push(proof.clone()); + } + + let mut accum = PlonkSuccinctVerifier::verify(svk, &protocol, 
&instances, &proof) + .map_err(|_| plonk::Error::Synthesis)?; + accumulators.append(&mut accum); + } + let accumulator = { + let mut transcript = PoseidonTranscript::, _>::new(loader, as_proof); + let proof = As::read_proof(&Default::default(), &accumulators, &mut transcript).unwrap(); + As::verify(&Default::default(), &accumulators, &proof).map_err(|_| plonk::Error::Synthesis) + }?; + Ok((accumulator, snark_instances)) +} + +/// The Halo2 Config for the aggregation circuit +#[derive(Clone, Debug)] +pub struct AggregationConfig { + main_gate_config: MainGateConfig, + range_config: RangeConfig, +} + +impl AggregationConfig { + /// Configure the aggregation circuit + pub fn configure( + meta: &mut ConstraintSystem, + composition_bits: Vec, + overflow_bits: Vec, + ) -> Self { + let main_gate_config = MainGate::::configure(meta); + let range_config = + RangeChip::::configure(meta, &main_gate_config, composition_bits, overflow_bits); + AggregationConfig { + main_gate_config, + range_config, + } + } + + /// Create a MainGate from the aggregation approach + pub fn main_gate(&self) -> MainGate { + MainGate::new(self.main_gate_config.clone()) + } + + /// Create a range chip to decompose and range check inputs + pub fn range_chip(&self) -> RangeChip { + RangeChip::new(self.range_config.clone()) + } + + /// Create an ecc chip for ec ops + pub fn ecc_chip(&self) -> BaseFieldEccChip { + BaseFieldEccChip::new(EccConfig::new( + self.range_config.clone(), + self.main_gate_config.clone(), + )) + } +} + +/// Aggregation Circuit with a SuccinctVerifyingKey, application snark witnesses (each with a proof and instance variables), and the instance variables and the resulting aggregation circuit proof. 
+#[derive(Clone, Debug)] +pub struct AggregationCircuit { + svk: Svk, + snarks: Vec>, + instances: Vec, + as_proof: Value>, + split_proof: bool, +} + +impl AggregationCircuit { + /// Create a new Aggregation Circuit with a SuccinctVerifyingKey, application snark witnesses (each with a proof and instance variables), and the instance variables and the resulting aggregation circuit proof. + pub fn new( + svk: &KzgSuccinctVerifyingKey, + snarks: impl IntoIterator>, + split_proof: bool, + ) -> Result { + let snarks = snarks.into_iter().collect_vec(); + + let mut accumulators = vec![]; + + for snark in snarks.iter() { + trace!("Aggregating with snark instances {:?}", snark.instances); + let mut transcript = PoseidonTranscript::::new(snark.proof.as_slice()); + let proof = PlonkSuccinctVerifier::read_proof( + svk, + snark.protocol.as_ref().unwrap(), + &snark.instances, + &mut transcript, + ) + .map_err(|e| { + log::error!("{:?}", e); + AggregationError::ProofRead + })?; + let mut accum = PlonkSuccinctVerifier::verify( + svk, + snark.protocol.as_ref().unwrap(), + &snark.instances, + &proof, + ) + .map_err(|_| AggregationError::ProofVerify)?; + accumulators.append(&mut accum); + } + + trace!("Accumulator"); + let (accumulator, as_proof) = { + let mut transcript = PoseidonTranscript::::new(Vec::new()); + let accumulator = + As::create_proof(&Default::default(), &accumulators, &mut transcript, OsRng) + .map_err(|_| AggregationError::ProofCreate)?; + (accumulator, transcript.finalize()) + }; + + trace!("KzgAccumulator"); + let KzgAccumulator { lhs, rhs } = accumulator; + let instances = [lhs.x, lhs.y, rhs.x, rhs.y] + .map(fe_to_limbs::<_, _, LIMBS, BITS>) + .concat(); + + Ok(Self { + svk: *svk, + snarks: snarks.into_iter().map_into().collect(), + instances, + as_proof: Value::known(as_proof), + split_proof, + }) + } + + /// + pub fn num_limbs() -> usize { + LIMBS + } + /// + pub fn num_bits() -> usize { + BITS + } + + /// Accumulator indices used in generating verifier. 
+ pub fn accumulator_indices() -> Vec<(usize, usize)> { + (0..4 * LIMBS).map(|idx| (0, idx)).collect() + } + + /// Number of instance variables for the aggregation circuit, used in generating verifier. + pub fn num_instance(orginal_circuit_instances: usize) -> Vec { + let accumulation_instances = 4 * LIMBS; + vec![accumulation_instances + orginal_circuit_instances] + } + + /// Instance variables for the aggregation circuit, fed to verifier. + pub fn instances(&self) -> Vec { + // also get snark instances here + let mut snark_instances: Vec>>> = self + .snarks + .iter() + .map(|snark| snark.instances.clone()) + .collect_vec(); + + // reduce from Vec>>> to Vec>> + let mut instances: Vec = self.instances.clone(); + for snark_instance in snark_instances.iter_mut() { + for instance in snark_instance.iter_mut() { + let mut felt_evals = vec![]; + for value in instance.iter_mut() { + value.map(|v| felt_evals.push(v)); + } + instances.extend(felt_evals); + } + } + + instances + } + + fn as_proof(&self) -> Value<&[u8]> { + self.as_proof.as_ref().map(Vec::as_slice) + } +} + +impl Circuit for AggregationCircuit { + type Config = AggregationConfig; + type FloorPlanner = SimpleFloorPlanner; + type Params = (); + + fn without_witnesses(&self) -> Self { + Self { + svk: self.svk, + snarks: self + .snarks + .iter() + .map(SnarkWitness::without_witnesses) + .collect(), + instances: Vec::new(), + as_proof: Value::unknown(), + split_proof: self.split_proof, + } + } + + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + AggregationConfig::configure( + meta, + vec![BITS / LIMBS], + Rns::::construct().overflow_lengths(), + ) + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), plonk::Error> { + let main_gate = config.main_gate(); + let range_chip = config.range_chip(); + + range_chip.load_table(&mut layouter)?; + + let (accumulator_limbs, snark_instances) = layouter.assign_region( + || "", + |region| { + let ctx = 
RegionCtx::new(region, 0); + + let ecc_chip = config.ecc_chip(); + let loader = Halo2Loader::new(ecc_chip, ctx); + let (accumulator, snark_instances) = aggregate( + &self.svk, + &loader, + &self.snarks, + self.as_proof(), + self.split_proof, + )?; + + let accumulator_limbs = [accumulator.lhs, accumulator.rhs] + .iter() + .map(|ec_point| { + loader + .ecc_chip() + .assign_ec_point_to_limbs(&mut loader.ctx_mut(), ec_point.assigned()) + }) + .collect::, plonk::Error>>()? + .into_iter() + .flatten(); + + Ok((accumulator_limbs, snark_instances)) + }, + )?; + + let mut instance_offset = 0; + for limb in accumulator_limbs { + main_gate.expose_public(layouter.namespace(|| ""), limb, instance_offset)?; + instance_offset += 1; + } + + for instance in snark_instances.into_iter() { + for elem in instance.into_iter() { + main_gate.expose_public(layouter.namespace(|| ""), elem, instance_offset)?; + instance_offset += 1; + } + } + + Ok(()) + } +} diff --git a/mnist_ezkl/src/pfsys/evm/mod.rs b/mnist_ezkl/src/pfsys/evm/mod.rs new file mode 100644 index 0000000..3d3d916 --- /dev/null +++ b/mnist_ezkl/src/pfsys/evm/mod.rs @@ -0,0 +1,31 @@ +use thiserror::Error; + +/// Aggregate proof generation for EVM +pub mod aggregation; +/// Simple (single) proof generation for EVM +pub mod single; + +#[derive(Error, Debug)] +/// Errors related to evm verification +pub enum EvmVerificationError { + /// If the Solidity verifier worked but returned false + #[error("Solidity verifier found the proof invalid")] + InvalidProof, + /// If the Solidity verifier threw and error (e.g. 
OutOfGas) + #[error("Execution of Solidity code failed")] + SolidityExecution, + /// EVM execution errors + #[error("EVM execution of raw code failed")] + RawExecution, + /// EVM verify errors + #[error("evm verification reverted")] + Reverted, + /// EVM verify errors + #[error("evm deployment failed")] + Deploy, + /// Invalid Visibilit + #[error("Invalid visibility")] + InvalidVisibility, +} +/// YulCode type which is just an alias of string +pub type YulCode = String; diff --git a/mnist_ezkl/src/pfsys/evm/single.rs b/mnist_ezkl/src/pfsys/evm/single.rs new file mode 100644 index 0000000..2b71b5f --- /dev/null +++ b/mnist_ezkl/src/pfsys/evm/single.rs @@ -0,0 +1,53 @@ +use crate::pfsys::evm::YulCode; +use halo2_proofs::poly::commitment::ParamsProver; +use halo2_proofs::{plonk::VerifyingKey, poly::kzg::commitment::ParamsKZG}; +use halo2curves::bn256::{Bn256, Fq, Fr, G1Affine}; +use snark_verifier::{ + loader::evm::EvmLoader, + pcs::kzg::{Gwc19, KzgAs}, + system::halo2::{compile, transcript::evm::EvmTranscript, Config}, + verifier::{self, SnarkVerifier}, +}; +use std::rc::Rc; +use thiserror::Error; + +type PlonkVerifier = verifier::plonk::PlonkVerifier>; + +#[derive(Error, Debug)] +/// Errors related to simple evm verifier generation +pub enum SimpleError { + /// proof read errors + #[error("Failed to read proof")] + ProofRead, + /// proof verification errors + #[error("Failed to verify proof")] + ProofVerify, +} + +/// Create EVM verifier yulcode +pub fn gen_evm_verifier( + params: &ParamsKZG, + vk: &VerifyingKey, + num_instance: usize, +) -> Result { + let protocol = compile( + params, + vk, + Config::kzg().with_num_instance(vec![num_instance]), + ); + let vk = (params.get_g()[0], params.g2(), params.s_g2()).into(); + + let loader = EvmLoader::new::(); + let protocol = protocol.loaded(&loader); + let mut transcript = EvmTranscript::<_, Rc, _, _>::new(&loader); + + let instances = transcript.load_instances(vec![num_instance]); + let proof = 
PlonkVerifier::read_proof(&vk, &protocol, &instances, &mut transcript) + .map_err(|_| SimpleError::ProofRead)?; + PlonkVerifier::verify(&vk, &protocol, &instances, &proof) + .map_err(|_| SimpleError::ProofVerify)?; + + let yul_code = &loader.yul_code(); + + Ok(yul_code.clone()) +} diff --git a/mnist_ezkl/src/pfsys/mod.rs b/mnist_ezkl/src/pfsys/mod.rs new file mode 100644 index 0000000..943faf5 --- /dev/null +++ b/mnist_ezkl/src/pfsys/mod.rs @@ -0,0 +1,866 @@ +/// EVM related proving and verification +pub mod evm; + +/// SRS generation, processing, verification and downloading +pub mod srs; + +use crate::circuit::CheckMode; +use crate::graph::GraphWitness; +use crate::pfsys::evm::aggregation::PoseidonTranscript; +use crate::tensor::TensorType; +use clap::ValueEnum; +use halo2_proofs::circuit::Value; +use halo2_proofs::plonk::{ + create_proof, keygen_pk, keygen_vk, verify_proof, Circuit, ProvingKey, VerifyingKey, +}; +use halo2_proofs::poly::commitment::{CommitmentScheme, Params, ParamsProver, Prover, Verifier}; +use halo2_proofs::poly::kzg::commitment::{KZGCommitmentScheme, ParamsKZG}; +use halo2_proofs::poly::kzg::multiopen::{ProverSHPLONK, VerifierSHPLONK}; +use halo2_proofs::poly::VerificationStrategy; +use halo2_proofs::transcript::{EncodedChallenge, TranscriptReadBuffer, TranscriptWriterBuffer}; +use halo2curves::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup}; +use halo2curves::serde::SerdeObject; +use halo2curves::CurveAffine; +use instant::Instant; +use log::{debug, info, trace}; +#[cfg(not(feature = "det-prove"))] +use rand::rngs::OsRng; +#[cfg(feature = "det-prove")] +use rand::rngs::StdRng; +use serde::de::DeserializeOwned; +use serde::{Deserialize, Serialize}; +use snark_verifier::loader::native::NativeLoader; +use snark_verifier::system::halo2::transcript::evm::EvmTranscript; +use snark_verifier::system::halo2::{compile, Config}; +use snark_verifier::verifier::plonk::PlonkProtocol; +use std::error::Error; +use std::fs::File; +use 
std::io::{self, BufReader, BufWriter, Cursor, Write}; +use std::ops::Deref; +use std::path::PathBuf; +use thiserror::Error as thisError; + +use halo2curves::bn256::{Bn256, Fr, G1Affine}; + +#[allow(missing_docs)] +#[derive(ValueEnum, Copy, Clone, Debug, PartialEq, Eq, Deserialize, Serialize, PartialOrd)] +pub enum ProofType { + Single, + ForAggr, +} + +impl From for TranscriptType { + fn from(val: ProofType) -> Self { + match val { + ProofType::Single => TranscriptType::EVM, + ProofType::ForAggr => TranscriptType::Poseidon, + } + } +} + +impl From for StrategyType { + fn from(val: ProofType) -> Self { + match val { + ProofType::Single => StrategyType::Single, + ProofType::ForAggr => StrategyType::Accum, + } + } +} + +#[cfg(feature = "python-bindings")] +impl ToPyObject for ProofType { + fn to_object(&self, py: Python) -> PyObject { + match self { + ProofType::Single => "Single".to_object(py), + ProofType::ForAggr => "ForAggr".to_object(py), + } + } +} + +#[cfg(feature = "python-bindings")] +/// Obtains StrategyType from PyObject (Required for StrategyType to be compatible with Python) +impl<'source> pyo3::FromPyObject<'source> for ProofType { + fn extract(ob: &'source pyo3::PyAny) -> pyo3::PyResult { + let trystr = ::try_from(ob)?; + let strval = trystr.to_string(); + match strval.to_lowercase().as_str() { + "single" => Ok(ProofType::Single), + "for-aggr" => Ok(ProofType::ForAggr), + _ => Err(pyo3::exceptions::PyValueError::new_err( + "Invalid value for ProofType", + )), + } + } +} + +#[allow(missing_docs)] +#[derive(ValueEnum, Copy, Clone, Debug, PartialEq, Eq, Deserialize, Serialize)] +pub enum StrategyType { + Single, + Accum, +} +impl std::fmt::Display for StrategyType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.to_possible_value() + .expect("no values are skipped") + .get_name() + .fmt(f) + } +} +#[cfg(feature = "python-bindings")] +/// Converts StrategyType into a PyObject (Required for StrategyType to be compatible with 
Python) +impl pyo3::IntoPy for StrategyType { + fn into_py(self, py: Python) -> PyObject { + match self { + StrategyType::Single => "single".to_object(py), + StrategyType::Accum => "accum".to_object(py), + } + } +} +#[cfg(feature = "python-bindings")] +/// Obtains StrategyType from PyObject (Required for StrategyType to be compatible with Python) +impl<'source> pyo3::FromPyObject<'source> for StrategyType { + fn extract(ob: &'source pyo3::PyAny) -> pyo3::PyResult { + let trystr = ::try_from(ob)?; + let strval = trystr.to_string(); + match strval.to_lowercase().as_str() { + "single" => Ok(StrategyType::Single), + "accum" => Ok(StrategyType::Accum), + _ => Err(pyo3::exceptions::PyValueError::new_err( + "Invalid value for StrategyType", + )), + } + } +} + +#[derive(thisError, Debug)] +/// Errors related to pfsys +pub enum PfSysError { + /// Packing exponent is too large + #[error("largest packing exponent exceeds max. try reducing the scale")] + PackingExponent, +} + +#[allow(missing_docs)] +#[derive(ValueEnum, Copy, Clone, Debug, PartialEq, Eq, Deserialize, Serialize, PartialOrd)] +pub enum TranscriptType { + Poseidon, + EVM, +} + +#[cfg(feature = "python-bindings")] +impl ToPyObject for TranscriptType { + fn to_object(&self, py: Python) -> PyObject { + match self { + TranscriptType::Poseidon => "Poseidon".to_object(py), + TranscriptType::EVM => "EVM".to_object(py), + } + } +} + +#[cfg(feature = "python-bindings")] +/// +pub fn g1affine_to_pydict(g1affine_dict: &PyDict, g1affine: &G1Affine) { + let g1affine_x = field_to_vecu64_montgomery(&g1affine.x); + let g1affine_y = field_to_vecu64_montgomery(&g1affine.y); + g1affine_dict.set_item("x", g1affine_x).unwrap(); + g1affine_dict.set_item("y", g1affine_y).unwrap(); +} + +#[cfg(feature = "python-bindings")] +use halo2curves::bn256::G1; +#[cfg(feature = "python-bindings")] +/// +pub fn g1_to_pydict(g1_dict: &PyDict, g1: &G1) { + let g1_x = field_to_vecu64_montgomery(&g1.x); + let g1_y = field_to_vecu64_montgomery(&g1.y); 
+ let g1_z = field_to_vecu64_montgomery(&g1.z); + g1_dict.set_item("x", g1_x).unwrap(); + g1_dict.set_item("y", g1_y).unwrap(); + g1_dict.set_item("z", g1_z).unwrap(); +} + +/// converts fp into `Vec` in Montgomery form +pub fn field_to_vecu64_montgomery(fp: &F) -> [u64; 4] { + let repr = serde_json::to_string(&fp).unwrap(); + let b: [u64; 4] = serde_json::from_str(&repr).unwrap(); + b +} + +/// converts `Vec` in Montgomery form into fp +pub fn vecu64_to_field_montgomery( + b: &[u64; 4], +) -> F { + let repr = serde_json::to_string(&b).unwrap(); + let fp: F = serde_json::from_str(&repr).unwrap(); + fp +} + +/// An application snark with proof and instance variables ready for aggregation (raw field element) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Snark +where + C::Scalar: Serialize + DeserializeOwned, + C::ScalarExt: Serialize + DeserializeOwned, +{ + /// the protocol + pub protocol: Option>, + /// public instances of the snark + pub instances: Vec>, + /// the proof + pub proof: Vec, + /// transcript type + pub transcript_type: TranscriptType, + /// the split proof + pub split: Option, +} + +#[cfg(feature = "python-bindings")] +use pyo3::{types::PyDict, PyObject, Python, ToPyObject}; +#[cfg(feature = "python-bindings")] +impl ToPyObject for Snark +where + C::Scalar: Serialize + DeserializeOwned, + C::ScalarExt: Serialize + DeserializeOwned, +{ + fn to_object(&self, py: Python) -> PyObject { + let dict = PyDict::new(py); + let field_elems: Vec> = self + .instances + .iter() + .map(|x| x.iter().map(|fp| field_to_vecu64_montgomery(fp)).collect()) + .collect::>(); + dict.set_item("instances", &field_elems).unwrap(); + let hex_proof = hex::encode(&self.proof); + dict.set_item("proof", &hex_proof).unwrap(); + dict.set_item("transcript_type", &self.transcript_type) + .unwrap(); + dict.to_object(py) + } +} + +impl< + F: PrimeField + SerdeObject + Serialize + FromUniformBytes<64> + DeserializeOwned, + C: CurveAffine + Serialize + DeserializeOwned, + > 
Snark +where + C::Scalar: Serialize + DeserializeOwned, + C::ScalarExt: Serialize + DeserializeOwned, +{ + /// Create a new application snark from proof and instance variables ready for aggregation + pub fn new( + protocol: PlonkProtocol, + instances: Vec>, + proof: Vec, + transcript_type: TranscriptType, + split: Option, + ) -> Self { + Self { + protocol: Some(protocol), + instances, + proof, + transcript_type, + split, + } + } + + /// Saves the Proof to a specified `proof_path`. + pub fn save(&self, proof_path: &PathBuf) -> Result<(), Box> { + let file = std::fs::File::create(proof_path)?; + let mut writer = BufWriter::new(file); + serde_json::to_writer(&mut writer, &self)?; + Ok(()) + } + + /// Load a json serialized proof from the provided path. + pub fn load>( + proof_path: &PathBuf, + ) -> Result> + where + ::ScalarExt: FromUniformBytes<64>, + { + trace!("reading proof"); + let data = std::fs::read_to_string(proof_path)?; + serde_json::from_str(&data).map_err(|e| e.into()) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +/// A proof split commit +pub struct ProofSplitCommit { + /// The start index of the output in the witness + start: usize, + /// The end index of the output in the witness + end: usize, +} + +impl From for Option { + fn from(witness: GraphWitness) -> Self { + let mut elem_offset = 0; + + if let Some(input) = witness.processed_inputs { + if let Some(kzg) = input.kzg_commit { + // flatten and count number of elements + let num_elements = kzg.iter().map(|kzg| kzg.len()).sum::(); + + elem_offset += num_elements; + } + } + + if let Some(params) = witness.processed_params { + if let Some(kzg) = params.kzg_commit { + // flatten and count number of elements + let num_elements = kzg.iter().map(|kzg| kzg.len()).sum::(); + + elem_offset += num_elements; + } + } + + if let Some(output) = witness.processed_outputs { + if let Some(kzg) = output.kzg_commit { + // flatten and count number of elements + let num_elements = kzg.iter().map(|kzg| 
kzg.len()).sum::(); + + Some(ProofSplitCommit { + start: elem_offset, + end: elem_offset + num_elements, + }) + } else { + None + } + } else { + None + } + } +} + +/// An application snark with proof and instance variables ready for aggregation (wrapped field element) +#[derive(Clone, Debug)] +pub struct SnarkWitness { + protocol: Option>, + instances: Vec>>, + proof: Value>, + split: Option, +} + +impl SnarkWitness { + fn without_witnesses(&self) -> Self { + SnarkWitness { + protocol: self.protocol.clone(), + instances: self + .instances + .iter() + .map(|instances| vec![Value::unknown(); instances.len()]) + .collect(), + proof: Value::unknown(), + split: self.split.clone(), + } + } + + fn proof(&self) -> Value<&[u8]> { + self.proof.as_ref().map(Vec::as_slice) + } +} + +impl From> for SnarkWitness +where + C::Scalar: Serialize + DeserializeOwned, + C::ScalarExt: Serialize + DeserializeOwned, +{ + fn from(snark: Snark) -> Self { + Self { + protocol: snark.protocol, + instances: snark + .instances + .into_iter() + .map(|instances| instances.into_iter().map(Value::known).collect()) + .collect(), + proof: Value::known(snark.proof), + split: snark.split, + } + } +} + +/// Creates a [VerifyingKey] and [ProvingKey] for a [crate::graph::GraphCircuit] (`circuit`) with specific [CommitmentScheme] parameters (`params`). 
+pub fn create_keys>( + circuit: &C, + params: &'_ Scheme::ParamsProver, +) -> Result, halo2_proofs::plonk::Error> +where + C: Circuit, + ::Scalar: FromUniformBytes<64>, +{ + // Real proof + let empty_circuit = >::without_witnesses(circuit); + + // Initialize verifying key + let now = Instant::now(); + trace!("preparing VK"); + let vk = keygen_vk(params, &empty_circuit)?; + let elapsed = now.elapsed(); + info!("VK took {}.{}", elapsed.as_secs(), elapsed.subsec_millis()); + + // Initialize the proving key + let now = Instant::now(); + let pk = keygen_pk(params, vk, &empty_circuit)?; + let elapsed = now.elapsed(); + info!("PK took {}.{}", elapsed.as_secs(), elapsed.subsec_millis()); + Ok(pk) +} + +/// a wrapper around halo2's create_proof +#[allow(clippy::too_many_arguments)] +pub fn create_proof_circuit< + 'params, + Scheme: CommitmentScheme, + F: PrimeField + TensorType, + C: Circuit, + P: Prover<'params, Scheme>, + V: Verifier<'params, Scheme>, + Strategy: VerificationStrategy<'params, Scheme, V>, + E: EncodedChallenge, + TW: TranscriptWriterBuffer, Scheme::Curve, E>, + TR: TranscriptReadBuffer>, Scheme::Curve, E>, +>( + circuit: C, + instances: Vec>, + params: &'params Scheme::ParamsProver, + pk: &ProvingKey, + strategy: Strategy, + check_mode: CheckMode, + transcript_type: TranscriptType, + split: Option, +) -> Result, Box> +where + C: Circuit, + Scheme::ParamsVerifier: 'params, + Scheme::Scalar: Serialize + + DeserializeOwned + + SerdeObject + + PrimeField + + FromUniformBytes<64> + + WithSmallOrderMulGroup<3> + + Ord, + Scheme::Curve: Serialize + DeserializeOwned, +{ + let mut transcript = TranscriptWriterBuffer::<_, Scheme::Curve, _>::init(vec![]); + #[cfg(feature = "det-prove")] + let mut rng = ::from_seed([0u8; 32]); + #[cfg(not(feature = "det-prove"))] + let mut rng = OsRng; + let number_instance = instances.iter().map(|x| x.len()).collect(); + trace!("number_instance {:?}", number_instance); + let protocol = compile( + params, + pk.get_vk(), + 
Config::kzg().with_num_instance(number_instance), + ); + + let pi_inner = instances + .iter() + .map(|e| e.deref()) + .collect::>(); + let pi_inner: &[&[&[Scheme::Scalar]]] = &[&pi_inner]; + trace!("instances {:?}", instances); + trace!( + "pk num instance column: {:?}", + pk.get_vk().cs().num_instance_columns() + ); + + info!("proof started..."); + // not wasm32 unknown + let now = Instant::now(); + + create_proof::( + params, + pk, + &[circuit], + pi_inner, + &mut rng, + &mut transcript, + )?; + let proof = transcript.finalize(); + + let checkable_pf = Snark::new(protocol, instances, proof, transcript_type, split); + + // sanity check that the generated proof is valid + if check_mode == CheckMode::SAFE { + debug!("verifying generated proof"); + let verifier_params = params.verifier_params(); + verify_proof_circuit::( + &checkable_pf, + verifier_params, + pk.get_vk(), + strategy, + )?; + } + let elapsed = now.elapsed(); + info!( + "proof took {}.{}", + elapsed.as_secs(), + elapsed.subsec_millis() + ); + + Ok(checkable_pf) +} + +/// Swaps the proof commitments to a new set in the proof +pub fn swap_proof_commitments< + F: PrimeField, + Scheme: CommitmentScheme, + E: EncodedChallenge, + TW: TranscriptWriterBuffer, Scheme::Curve, E>, +>( + snark: &Snark, + commitments: &[Scheme::Curve], +) -> Result, Box> +where + Scheme::Scalar: SerdeObject + + PrimeField + + FromUniformBytes<64> + + WithSmallOrderMulGroup<3> + + Ord + + Serialize + + DeserializeOwned, + Scheme::Curve: Serialize + DeserializeOwned, +{ + let mut transcript_new: TW = TranscriptWriterBuffer::<_, Scheme::Curve, _>::init(vec![]); + + // kzg commitments are the first set of points in the proof, this we'll always be the first set of advice + for commit in commitments { + transcript_new + .write_point(*commit) + .map_err(|_| "failed to write point")?; + } + + let proof_first_bytes = transcript_new.finalize(); + + let mut snark_new = snark.clone(); + // swap the proof bytes for the new ones + 
snark_new.proof[..proof_first_bytes.len()].copy_from_slice(&proof_first_bytes); + + Ok(snark_new) +} + +/// Swap the proof commitments to a new set in the proof for KZG +pub fn swap_proof_commitments_kzg( + snark: &Snark, + commitments: &[G1Affine], +) -> Result, Box> { + let proof = match snark.transcript_type { + TranscriptType::EVM => swap_proof_commitments::< + Fr, + KZGCommitmentScheme, + _, + EvmTranscript, + >(snark, commitments)?, + TranscriptType::Poseidon => swap_proof_commitments::< + Fr, + KZGCommitmentScheme, + _, + PoseidonTranscript, + >(snark, commitments)?, + }; + Ok(proof) +} + +/// A wrapper around halo2's verify_proof +pub fn verify_proof_circuit< + 'params, + F: PrimeField, + V: Verifier<'params, Scheme>, + Scheme: CommitmentScheme, + Strategy: VerificationStrategy<'params, Scheme, V>, + E: EncodedChallenge, + TR: TranscriptReadBuffer>, Scheme::Curve, E>, +>( + snark: &Snark, + params: &'params Scheme::ParamsVerifier, + vk: &VerifyingKey, + strategy: Strategy, +) -> Result +where + Scheme::Scalar: SerdeObject + + PrimeField + + FromUniformBytes<64> + + WithSmallOrderMulGroup<3> + + Ord + + Serialize + + DeserializeOwned, + Scheme::Curve: Serialize + DeserializeOwned, +{ + let pi_inner = snark + .instances + .iter() + .map(|e| e.deref()) + .collect::>(); + let instances: &[&[&[Scheme::Scalar]]] = &[&pi_inner]; + trace!("instances {:?}", instances); + + let mut transcript = TranscriptReadBuffer::init(Cursor::new(snark.proof.clone())); + verify_proof::(params, vk, strategy, instances, &mut transcript) +} + +/// Loads a [VerifyingKey] at `path`. 
+pub fn load_vk>( + path: PathBuf, + params: >::Params, +) -> Result, Box> +where + C: Circuit, + Scheme::Curve: SerdeObject + CurveAffine, + Scheme::Scalar: PrimeField + SerdeObject + FromUniformBytes<64>, +{ + info!("loading verification key from {:?}", path); + let f = + File::open(path.clone()).map_err(|_| format!("failed to load vk at {}", path.display()))?; + let mut reader = BufReader::new(f); + VerifyingKey::::read::<_, C>( + &mut reader, + halo2_proofs::SerdeFormat::RawBytes, + params, + ) + .map_err(Box::::from) +} + +/// Loads a [ProvingKey] at `path`. +pub fn load_pk>( + path: PathBuf, + params: >::Params, +) -> Result, Box> +where + C: Circuit, + Scheme::Curve: SerdeObject + CurveAffine, + Scheme::Scalar: PrimeField + SerdeObject + FromUniformBytes<64>, +{ + info!("loading proving key from {:?}", path); + let f = + File::open(path.clone()).map_err(|_| format!("failed to load pk at {}", path.display()))?; + let mut reader = BufReader::new(f); + ProvingKey::::read::<_, C>( + &mut reader, + halo2_proofs::SerdeFormat::RawBytes, + params, + ) + .map_err(Box::::from) +} + +/// Saves a [ProvingKey] to `path`. +pub fn save_pk( + path: &PathBuf, + vk: &ProvingKey, +) -> Result<(), io::Error> +where + Scheme::Curve: SerdeObject + CurveAffine, + Scheme::Scalar: PrimeField + SerdeObject + FromUniformBytes<64>, +{ + info!("saving proving key 💾"); + let f = File::create(path)?; + let mut writer = BufWriter::new(f); + vk.write(&mut writer, halo2_proofs::SerdeFormat::RawBytes)?; + writer.flush()?; + Ok(()) +} + +/// Saves a [VerifyingKey] to `path`. 
+pub fn save_vk( + path: &PathBuf, + vk: &VerifyingKey, +) -> Result<(), io::Error> +where + Scheme::Curve: SerdeObject + CurveAffine, + Scheme::Scalar: PrimeField + SerdeObject + FromUniformBytes<64>, +{ + info!("saving verification key 💾"); + let f = File::create(path)?; + let mut writer = BufWriter::new(f); + vk.write(&mut writer, halo2_proofs::SerdeFormat::RawBytes)?; + writer.flush()?; + Ok(()) +} + +/// Saves [CommitmentScheme] parameters to `path`. +pub fn save_params( + path: &PathBuf, + params: &'_ Scheme::ParamsVerifier, +) -> Result<(), io::Error> { + info!("saving parameters 💾"); + let f = File::create(path)?; + let mut writer = BufWriter::new(f); + params.write(&mut writer)?; + writer.flush()?; + Ok(()) +} + +/// helper function +#[allow(clippy::too_many_arguments)] +pub fn create_proof_circuit_kzg< + 'params, + C: Circuit, + Strategy: VerificationStrategy<'params, KZGCommitmentScheme, VerifierSHPLONK<'params, Bn256>>, +>( + circuit: C, + params: &'params ParamsKZG, + public_inputs: Option>, + pk: &ProvingKey, + transcript: TranscriptType, + strategy: Strategy, + check_mode: CheckMode, + split: Option, +) -> Result, Box> { + let public_inputs = if let Some(public_inputs) = public_inputs { + if !public_inputs.is_empty() { + vec![public_inputs] + } else { + vec![vec![]] + } + } else { + vec![] + }; + + match transcript { + TranscriptType::EVM => create_proof_circuit::< + KZGCommitmentScheme<_>, + Fr, + _, + ProverSHPLONK<_>, + VerifierSHPLONK<_>, + _, + _, + EvmTranscript, + EvmTranscript, + >( + circuit, + public_inputs, + params, + pk, + strategy, + check_mode, + transcript, + split, + ) + .map_err(Box::::from), + TranscriptType::Poseidon => create_proof_circuit::< + KZGCommitmentScheme<_>, + Fr, + _, + ProverSHPLONK<_>, + VerifierSHPLONK<_>, + _, + _, + PoseidonTranscript, + PoseidonTranscript, + >( + circuit, + public_inputs, + params, + pk, + strategy, + check_mode, + transcript, + split, + ) + .map_err(Box::::from), + } +} + +#[allow(unused)] +/// 
helper function +pub(crate) fn verify_proof_circuit_kzg< + 'params, + Strategy: VerificationStrategy<'params, KZGCommitmentScheme, VerifierSHPLONK<'params, Bn256>>, +>( + params: &'params ParamsKZG, + proof: Snark, + vk: &VerifyingKey, + strategy: Strategy, +) -> Result { + match proof.transcript_type { + TranscriptType::EVM => verify_proof_circuit::< + Fr, + VerifierSHPLONK<'_, Bn256>, + _, + _, + _, + EvmTranscript, + >(&proof, params, vk, strategy), + TranscriptType::Poseidon => verify_proof_circuit::< + Fr, + VerifierSHPLONK<'_, Bn256>, + _, + _, + _, + PoseidonTranscript, + >(&proof, params, vk, strategy), + } +} + +//////////////////////// + +#[cfg(test)] +#[cfg(not(target_arch = "wasm32"))] +mod tests { + use std::io::copy; + + use super::*; + use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme; + use halo2curves::bn256::{Bn256, Fr, G1Affine}; + use tempfile::Builder; + + #[tokio::test] + async fn test_can_load_pre_generated_srs() { + let tmp_dir = Builder::new().prefix("example").tempdir().unwrap(); + // lets hope this link never rots + let target = "https://trusted-setup-halo2kzg.s3.eu-central-1.amazonaws.com/hermez-raw-1"; + let response = reqwest::get(target).await.unwrap(); + + let fname = response + .url() + .path_segments() + .and_then(|segments| segments.last()) + .and_then(|name| if name.is_empty() { None } else { Some(name) }) + .unwrap_or("tmp.bin"); + + info!("file to download: '{}'", fname); + let fname = tmp_dir.path().join(fname); + info!("will be located under: '{:?}'", fname); + let mut dest = File::create(fname.clone()).unwrap(); + let content = response.bytes().await.unwrap(); + copy(&mut &content[..], &mut dest).unwrap(); + let res = srs::load_srs::>(fname); + assert!(res.is_ok()) + } + + #[tokio::test] + async fn test_can_load_saved_srs() { + let tmp_dir = Builder::new().prefix("example").tempdir().unwrap(); + let fname = tmp_dir.path().join("kzg.params"); + let srs = srs::gen_srs::>(1); + let res = save_params::>(&fname, &srs); 
+ assert!(res.is_ok()); + let res = srs::load_srs::>(fname); + assert!(res.is_ok()) + } + + #[test] + fn test_snark_serialization_roundtrip() { + let snark = Snark:: { + proof: vec![1, 2, 3, 4, 5, 6, 7, 8], + instances: vec![vec![Fr::from(1)], vec![Fr::from(2)]], + transcript_type: TranscriptType::EVM, + protocol: None, + split: None, + }; + + snark + .save(&"test_snark_serialization_roundtrip.json".into()) + .unwrap(); + let snark2 = Snark::::load::>( + &"test_snark_serialization_roundtrip.json".into(), + ) + .unwrap(); + assert_eq!(snark.instances, snark2.instances); + assert_eq!(snark.proof, snark2.proof); + assert_eq!(snark.transcript_type, snark2.transcript_type); + } +} diff --git a/mnist_ezkl/src/pfsys/srs.rs b/mnist_ezkl/src/pfsys/srs.rs new file mode 100644 index 0000000..dc54a1a --- /dev/null +++ b/mnist_ezkl/src/pfsys/srs.rs @@ -0,0 +1,28 @@ +use halo2_proofs::poly::commitment::CommitmentScheme; +use halo2_proofs::poly::commitment::Params; +use halo2_proofs::poly::commitment::ParamsProver; +use log::info; +use std::error::Error; +use std::fs::File; +use std::io::BufReader; +use std::path::PathBuf; + +/// for now we use the urls of the powers of tau ceremony from +pub const PUBLIC_SRS_URL: &str = + "https://trusted-setup-halo2kzg.s3.eu-central-1.amazonaws.com/perpetual-powers-of-tau-raw-"; + +/// Helper function for generating SRS. Only use for testing +pub fn gen_srs(k: u32) -> Scheme::ParamsProver { + Scheme::ParamsProver::new(k) +} + +/// Loads the [CommitmentScheme::ParamsVerifier] at `path`. 
+pub fn load_srs( + path: PathBuf, +) -> Result> { + info!("loading srs from {:?}", path); + let f = File::open(path.clone()) + .map_err(|_| format!("failed to load srs at {}", path.display()))?; + let mut reader = BufReader::new(f); + Params::<'_, Scheme::Curve>::read(&mut reader).map_err(Box::::from) +} diff --git a/mnist_ezkl/src/python.rs b/mnist_ezkl/src/python.rs new file mode 100644 index 0000000..ae92ce3 --- /dev/null +++ b/mnist_ezkl/src/python.rs @@ -0,0 +1,1302 @@ +use crate::circuit::modules::elgamal::{ElGamalCipher, ElGamalVariables}; +use crate::circuit::modules::kzg::KZGChip; +use crate::circuit::modules::poseidon::{ + spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH}, + PoseidonChip, +}; +use crate::circuit::modules::Module; +use crate::circuit::{CheckMode, Tolerance}; +use crate::commands::CalibrationTarget; +use crate::fieldutils::{felt_to_i128, i128_to_felt}; +use crate::graph::modules::POSEIDON_LEN_GRAPH; +use crate::graph::{ + quantize_float, scale_to_multiplier, GraphCircuit, GraphSettings, Model, Visibility, +}; +use crate::pfsys::evm::aggregation::AggregationCircuit; +use crate::pfsys::{ + load_pk, load_vk, save_params, save_vk, srs::gen_srs as ezkl_gen_srs, srs::load_srs, ProofType, + Snark, TranscriptType, +}; +use crate::RunArgs; +use ethers::types::H160; +use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme; +use halo2curves::bn256::{Bn256, Fq, Fr, G1Affine, G1}; +use pyo3::exceptions::{PyIOError, PyRuntimeError}; +use pyo3::prelude::*; +use pyo3::wrap_pyfunction; +use pyo3_log; +use rand::rngs::StdRng; +use rand::SeedableRng; +use snark_verifier::util::arithmetic::PrimeField; +use std::str::FromStr; +use std::{fs::File, path::PathBuf}; +use tokio::runtime::Runtime; + +type PyFelt = [u64; 4]; + +/// pyclass containing the struct used for G1 +#[pyclass] +#[derive(Debug, Clone)] +struct PyG1 { + #[pyo3(get, set)] + x: PyFelt, + #[pyo3(get, set)] + y: PyFelt, + #[pyo3(get, set)] + z: PyFelt, +} + +impl From for PyG1 { + fn 
from(g1: G1) -> Self { + PyG1 { + x: crate::pfsys::field_to_vecu64_montgomery::(&g1.x), + y: crate::pfsys::field_to_vecu64_montgomery::(&g1.y), + z: crate::pfsys::field_to_vecu64_montgomery::(&g1.z), + } + } +} + +impl Into for PyG1 { + fn into(self) -> G1 { + G1 { + x: crate::pfsys::vecu64_to_field_montgomery::(&self.x), + y: crate::pfsys::vecu64_to_field_montgomery::(&self.y), + z: crate::pfsys::vecu64_to_field_montgomery::(&self.z), + } + } +} + +impl pyo3::ToPyObject for PyG1 { + fn to_object(&self, py: pyo3::Python) -> pyo3::PyObject { + let g1_dict = pyo3::types::PyDict::new(py); + + g1_dict.set_item("x", self.x.to_object(py)).unwrap(); + g1_dict.set_item("y", self.y.to_object(py)).unwrap(); + g1_dict.set_item("z", self.z.to_object(py)).unwrap(); + g1_dict.into() + } +} + +/// pyclass containing the struct used for G1 +#[pyclass] +#[derive(Debug, Clone)] +pub struct PyG1Affine { + #[pyo3(get, set)] + /// + pub x: PyFelt, + #[pyo3(get, set)] + /// + pub y: PyFelt, +} + +impl From for PyG1Affine { + fn from(g1: G1Affine) -> Self { + PyG1Affine { + x: crate::pfsys::field_to_vecu64_montgomery::(&g1.x), + y: crate::pfsys::field_to_vecu64_montgomery::(&g1.y), + } + } +} + +impl Into for PyG1Affine { + fn into(self) -> G1Affine { + G1Affine { + x: crate::pfsys::vecu64_to_field_montgomery::(&self.x), + y: crate::pfsys::vecu64_to_field_montgomery::(&self.y), + } + } +} + +impl pyo3::ToPyObject for PyG1Affine { + fn to_object(&self, py: pyo3::Python) -> pyo3::PyObject { + let g1_dict = pyo3::types::PyDict::new(py); + + g1_dict.set_item("x", self.x.to_object(py)).unwrap(); + g1_dict.set_item("y", self.y.to_object(py)).unwrap(); + g1_dict.into() + } +} + +/// pyclass containing the struct used for ElgamalCipher +#[pyclass] +#[derive(Debug, Clone)] +pub struct PyElGamalCipher { + #[pyo3(get, set)] + /// + c1: PyG1, + #[pyo3(get, set)] + /// + c2: Vec, +} + +impl From for ElGamalCipher { + fn from(py_elgamal_cipher: PyElGamalCipher) -> Self { + ElGamalCipher { + c1: 
py_elgamal_cipher.c1.into(), + c2: py_elgamal_cipher + .c2 + .iter() + .map(|x| crate::pfsys::vecu64_to_field_montgomery::(&x)) + .collect::>(), + } + } +} + +impl From for PyElGamalCipher { + fn from(elgamal_cipher: ElGamalCipher) -> Self { + PyElGamalCipher { + c1: elgamal_cipher.c1.into(), + c2: elgamal_cipher + .c2 + .iter() + .map(|x| crate::pfsys::field_to_vecu64_montgomery::(&x)) + .collect::>(), + } + } +} + +/// pyclass containing the struct used for ElgamalVariables +#[pyclass] +#[derive(Debug, Clone)] +pub struct PyElGamalVariables { + #[pyo3(get, set)] + r: PyFelt, + #[pyo3(get, set)] + pk: PyG1Affine, + #[pyo3(get, set)] + sk: PyFelt, + #[pyo3(get, set)] + window_size: usize, + #[pyo3(get, set)] + aux_generator: PyG1Affine, +} + +impl From for ElGamalVariables { + fn from(py_elgamal_variables: PyElGamalVariables) -> Self { + ElGamalVariables { + r: crate::pfsys::vecu64_to_field_montgomery::(&py_elgamal_variables.r), + pk: G1Affine { + x: crate::pfsys::vecu64_to_field_montgomery::(&py_elgamal_variables.pk.x), + y: crate::pfsys::vecu64_to_field_montgomery::(&py_elgamal_variables.pk.y), + }, + sk: crate::pfsys::vecu64_to_field_montgomery::(&py_elgamal_variables.sk), + window_size: py_elgamal_variables.window_size, + aux_generator: G1Affine { + x: crate::pfsys::vecu64_to_field_montgomery::( + &py_elgamal_variables.aux_generator.x, + ), + y: crate::pfsys::vecu64_to_field_montgomery::( + &py_elgamal_variables.aux_generator.y, + ), + }, + } + } +} + +impl From for PyElGamalVariables { + fn from(elgamal_variables: ElGamalVariables) -> Self { + PyElGamalVariables { + r: crate::pfsys::field_to_vecu64_montgomery::(&elgamal_variables.r), + pk: PyG1Affine { + x: crate::pfsys::field_to_vecu64_montgomery::(&elgamal_variables.pk.x), + y: crate::pfsys::field_to_vecu64_montgomery::(&elgamal_variables.pk.y), + }, + sk: crate::pfsys::field_to_vecu64_montgomery::(&elgamal_variables.sk), + window_size: elgamal_variables.window_size, + aux_generator: PyG1Affine { + x: 
crate::pfsys::field_to_vecu64_montgomery::( + &elgamal_variables.aux_generator.x, + ), + y: crate::pfsys::field_to_vecu64_montgomery::( + &elgamal_variables.aux_generator.y, + ), + }, + } + } +} + +impl pyo3::ToPyObject for PyElGamalVariables { + fn to_object(&self, py: pyo3::Python) -> pyo3::PyObject { + let variables_dict = pyo3::types::PyDict::new(py); + + variables_dict.set_item("r", self.r.to_object(py)).unwrap(); + variables_dict + .set_item("pk", self.pk.to_object(py)) + .unwrap(); + variables_dict + .set_item("sk", self.sk.to_object(py)) + .unwrap(); + variables_dict + .set_item("window_size", self.window_size.to_object(py)) + .unwrap(); + variables_dict + .set_item("aux_generator", self.aux_generator.to_object(py)) + .unwrap(); + variables_dict.into() + } +} + +/// pyclass containing the struct used for run_args +#[pyclass] +#[derive(Clone)] +struct PyRunArgs { + #[pyo3(get, set)] + pub tolerance: f32, + #[pyo3(get, set)] + pub input_scale: crate::Scale, + #[pyo3(get, set)] + pub param_scale: crate::Scale, + #[pyo3(get, set)] + pub scale_rebase_multiplier: u32, + #[pyo3(get, set)] + pub lookup_range: (i128, i128), + #[pyo3(get, set)] + pub logrows: u32, + #[pyo3(get, set)] + pub num_inner_cols: usize, + #[pyo3(get, set)] + pub input_visibility: Visibility, + #[pyo3(get, set)] + pub output_visibility: Visibility, + #[pyo3(get, set)] + pub param_visibility: Visibility, + #[pyo3(get, set)] + pub variables: Vec<(String, usize)>, +} + +/// default instantiation of PyRunArgs +#[pymethods] +impl PyRunArgs { + #[new] + fn new() -> Self { + PyRunArgs { + tolerance: 0.0, + input_scale: 7, + param_scale: 7, + scale_rebase_multiplier: 1, + num_inner_cols: 2, + lookup_range: (-32768, 32768), + logrows: 17, + input_visibility: Visibility::Private, + output_visibility: Visibility::Public, + param_visibility: Visibility::Private, + variables: vec![("batch_size".to_string(), 1)], + } + } +} + +/// Conversion between PyRunArgs and RunArgs +impl From for RunArgs { + fn 
from(py_run_args: PyRunArgs) -> Self { + RunArgs { + tolerance: Tolerance::from(py_run_args.tolerance), + input_scale: py_run_args.input_scale, + param_scale: py_run_args.param_scale, + num_inner_cols: py_run_args.num_inner_cols, + scale_rebase_multiplier: py_run_args.scale_rebase_multiplier, + lookup_range: py_run_args.lookup_range, + logrows: py_run_args.logrows, + input_visibility: py_run_args.input_visibility, + output_visibility: py_run_args.output_visibility, + param_visibility: py_run_args.param_visibility, + variables: py_run_args.variables, + } + } +} + +impl Into for RunArgs { + fn into(self) -> PyRunArgs { + PyRunArgs { + tolerance: self.tolerance.val.into(), + input_scale: self.input_scale, + param_scale: self.param_scale, + num_inner_cols: self.num_inner_cols, + scale_rebase_multiplier: self.scale_rebase_multiplier, + lookup_range: self.lookup_range, + logrows: self.logrows, + input_visibility: self.input_visibility, + output_visibility: self.output_visibility, + param_visibility: self.param_visibility, + variables: self.variables, + } + } +} + +/// Converts 4 u64s to a field element +#[pyfunction(signature = ( + array, +))] +fn vecu64_to_felt(array: PyFelt) -> PyResult { + Ok(format!( + "{:?}", + crate::pfsys::vecu64_to_field_montgomery::(&array) + )) +} + +/// Converts 4 u64s representing a field element directly to an integer +#[pyfunction(signature = ( + array, +))] +fn vecu64_to_int(array: PyFelt) -> PyResult { + let felt = crate::pfsys::vecu64_to_field_montgomery::(&array); + let int_rep = felt_to_i128(felt); + Ok(int_rep) +} + +/// Converts 4 u64s representing a field element directly to a (rescaled from fixed point scaling) floating point +#[pyfunction(signature = ( + array, + scale +))] +fn vecu64_to_float(array: PyFelt, scale: crate::Scale) -> PyResult { + let felt = crate::pfsys::vecu64_to_field_montgomery::(&array); + let int_rep = felt_to_i128(felt); + let multiplier = scale_to_multiplier(scale); + let float_rep = int_rep as f64 / 
multiplier; + Ok(float_rep) +} + +/// Converts a floating point element to 4 u64s representing a fixed point field element +#[pyfunction(signature = ( +input, +scale +))] +fn float_to_vecu64(input: f64, scale: crate::Scale) -> PyResult { + let int_rep = quantize_float(&input, 0.0, scale) + .map_err(|_| PyIOError::new_err("Failed to quantize input"))?; + let felt = i128_to_felt(int_rep); + Ok(crate::pfsys::field_to_vecu64_montgomery::(&felt)) +} + +/// Converts a buffer to vector of 4 u64s representing a fixed point field element +#[pyfunction(signature = ( + buffer + ))] +fn buffer_to_felts(buffer: Vec) -> PyResult> { + fn u8_array_to_u128_le(arr: [u8; 16]) -> u128 { + let mut n: u128 = 0; + for &b in arr.iter().rev() { + n <<= 8; + n |= b as u128; + } + n + } + + let buffer = &buffer[..]; + + // Divide the buffer into chunks of 64 bytes + let chunks = buffer.chunks_exact(16); + + // Get the remainder + let remainder = chunks.remainder(); + + // Add 0s to the remainder to make it 64 bytes + let mut remainder = remainder.to_vec(); + + // Collect chunks into a Vec<[u8; 16]>. + let chunks: Result, PyErr> = chunks + .map(|slice| { + let array: [u8; 16] = slice + .try_into() + .map_err(|_| PyIOError::new_err("Failed to slice input buffer"))?; + Ok(array) + }) + .collect(); + + let mut chunks = chunks?; + + if remainder.len() != 0 { + remainder.resize(16, 0); + // Convert the Vec to [u8; 16] + let remainder_array: [u8; 16] = remainder + .try_into() + .map_err(|_| PyIOError::new_err("Failed to slice remainder"))?; + // append the remainder to the chunks + chunks.push(remainder_array); + } + + // Convert each chunk to a field element + let field_elements: Vec = chunks + .iter() + .map(|x| PrimeField::from_u128(u8_array_to_u128_le(*x))) + .collect(); + + let field_elements: Vec = field_elements.iter().map(|x| format!("{:?}", x)).collect(); + + Ok(field_elements) +} + +/// Generate a poseidon hash. 
+#[pyfunction(signature = ( + message, + ))] +fn poseidon_hash(message: Vec) -> PyResult> { + let message: Vec = message + .iter() + .map(|x| crate::pfsys::vecu64_to_field_montgomery::(&x)) + .collect::>(); + + let output = + PoseidonChip::::run( + message.clone(), + ) + .map_err(|_| PyIOError::new_err("Failed to run poseidon"))?; + + let hash = output[0] + .iter() + .map(|x| crate::pfsys::field_to_vecu64_montgomery::(&x)) + .collect::>(); + Ok(hash) +} + +/// Generate a kzg commitment. +#[pyfunction(signature = ( + message, + srs_path, + vk_path, + settings_path + ))] +fn kzg_commit( + message: Vec, + srs_path: PathBuf, + vk_path: PathBuf, + settings_path: PathBuf, +) -> PyResult> { + let message: Vec = message + .iter() + .map(|x| crate::pfsys::vecu64_to_field_montgomery::(&x)) + .collect::>(); + + let srs = load_srs::>(srs_path) + .map_err(|_| PyIOError::new_err("Failed to load srs"))?; + + let settings = GraphSettings::load(&settings_path) + .map_err(|_| PyIOError::new_err("Failed to load circuit settings"))?; + + let vk = load_vk::, Fr, GraphCircuit>(vk_path, settings) + .map_err(|_| PyIOError::new_err("Failed to load vk"))?; + + let output = KZGChip::commit( + message, + vk.cs().degree() as u32, + (vk.cs().blinding_factors() + 1) as u32, + &srs, + ); + + Ok(output.iter().map(|x| (*x).into()).collect::>()) +} + +/// Swap the commitments in a proof +#[pyfunction(signature = ( + proof_path, + witness_path, + ))] +fn swap_proof_commitments(proof_path: PathBuf, witness_path: PathBuf) -> PyResult<()> { + crate::execute::swap_proof_commitments(proof_path, witness_path) + .map_err(|_| PyIOError::new_err("Failed to swap commitments"))?; + + Ok(()) +} + +/// Encrypt using elgamal +#[pyfunction(signature = ( + pk, message, r + ))] +pub fn elgamal_encrypt( + pk: PyG1Affine, + message: Vec, + r: PyFelt, +) -> PyResult { + let pk: G1Affine = pk.into(); + let message = message + .iter() + .map(|x| crate::pfsys::vecu64_to_field_montgomery::(&x)) + .collect::>(); + let r = 
crate::pfsys::vecu64_to_field_montgomery::(&r); + + let output = crate::circuit::modules::elgamal::ElGamalGadget::encrypt(pk, message, r); + Ok(output.into()) +} + +/// Decrypt using elgamal +#[pyfunction(signature = ( + cipher, sk + ))] +pub fn elgamal_decrypt(cipher: PyElGamalCipher, sk: PyFelt) -> PyResult> { + let sk: Fr = crate::pfsys::vecu64_to_field_montgomery::(&sk); + + let output = crate::circuit::modules::elgamal::ElGamalGadget::decrypt(&cipher.into(), sk); + + let output = output + .iter() + .map(|x| crate::pfsys::field_to_vecu64_montgomery::(&x)) + .collect::>(); + + Ok(output) +} + +/// Generates random elgamal variables from a random seed value in browser. +/// Make sure input seed comes a secure source of randomness +#[pyfunction(signature = ( + rng + ))] +pub fn elgamal_gen_random(rng: Vec) -> PyResult { + let seed: &[u8] = &rng; + let mut rng = StdRng::from_seed( + seed.try_into() + .map_err(|_| PyIOError::new_err("Failed to create random seed"))?, + ); + + let output = crate::circuit::modules::elgamal::ElGamalVariables::gen_random(&mut rng); + + Ok(output.into()) +} + +/// Generates a vk from a pk for a model circuit and saves it to a file +#[pyfunction(signature = ( + path_to_pk, + circuit_settings_path, + vk_output_path + ))] +fn gen_vk_from_pk_single( + path_to_pk: PathBuf, + circuit_settings_path: PathBuf, + vk_output_path: PathBuf, +) -> PyResult { + let settings = GraphSettings::load(&circuit_settings_path) + .map_err(|_| PyIOError::new_err("Failed to load circuit settings"))?; + + let pk = load_pk::, Fr, GraphCircuit>(path_to_pk, settings) + .map_err(|_| PyIOError::new_err("Failed to load pk"))?; + + let vk = pk.get_vk(); + + // now save + save_vk::>(&vk_output_path, vk) + .map_err(|_| PyIOError::new_err("Failed to save vk"))?; + + Ok(true) +} + +/// Generates a vk from a pk for an aggregate circuit and saves it to a file +#[pyfunction(signature = ( + path_to_pk, + vk_output_path + ))] +fn gen_vk_from_pk_aggr(path_to_pk: PathBuf, 
vk_output_path: PathBuf) -> PyResult { + let pk = load_pk::, Fr, AggregationCircuit>(path_to_pk, ()) + .map_err(|_| PyIOError::new_err("Failed to load pk"))?; + + let vk = pk.get_vk(); + + // now save + save_vk::>(&vk_output_path, vk) + .map_err(|_| PyIOError::new_err("Failed to save vk"))?; + + Ok(true) +} + +/// Displays the table as a string in python +#[pyfunction(signature = ( + model, + py_run_args = None +))] +fn table(model: String, py_run_args: Option) -> PyResult { + let run_args: RunArgs = py_run_args.unwrap_or_else(PyRunArgs::new).into(); + let mut reader = File::open(model).map_err(|_| PyIOError::new_err("Failed to open model"))?; + let result = Model::new(&mut reader, &run_args); + + match result { + Ok(m) => Ok(m.table_nodes()), + Err(_) => Err(PyIOError::new_err("Failed to import model")), + } +} + +/// generates the srs +#[pyfunction(signature = ( + srs_path, + logrows, +))] +fn gen_srs(srs_path: PathBuf, logrows: usize) -> PyResult<()> { + let params = ezkl_gen_srs::>(logrows as u32); + save_params::>(&srs_path, ¶ms)?; + Ok(()) +} + +/// gets a public srs +#[pyfunction(signature = ( + srs_path, + settings_path=None, + logrows=None, +))] +fn get_srs( + srs_path: PathBuf, + settings_path: Option, + logrows: Option, +) -> PyResult { + Runtime::new() + .unwrap() + .block_on(crate::execute::get_srs_cmd( + srs_path, + settings_path, + logrows, + CheckMode::SAFE, + )) + .map_err(|e| { + let err_str = format!("Failed to get srs: {}", e); + PyRuntimeError::new_err(err_str) + })?; + Ok(true) +} + +/// generates the circuit settings +#[pyfunction(signature = ( + model, + output, + py_run_args = None, +))] +fn gen_settings( + model: PathBuf, + output: PathBuf, + py_run_args: Option, +) -> Result { + let run_args: RunArgs = py_run_args.unwrap_or_else(PyRunArgs::new).into(); + + crate::execute::gen_circuit_settings(model, output, run_args).map_err(|e| { + let err_str = format!("Failed to generate settings: {}", e); + PyRuntimeError::new_err(err_str) + })?; + + 
Ok(true) +} + +/// calibrates the circuit settings +#[pyfunction(signature = ( + data, + model, + settings, + target, + scales = None, + max_logrows = None, +))] +fn calibrate_settings( + data: PathBuf, + model: PathBuf, + settings: PathBuf, + target: Option, + scales: Option>, + max_logrows: Option, +) -> Result { + let target = target.unwrap_or(CalibrationTarget::Resources { + col_overflow: false, + }); + crate::execute::calibrate(model, data, settings, target, scales, max_logrows).map_err(|e| { + let err_str = format!("Failed to calibrate settings: {}", e); + PyRuntimeError::new_err(err_str) + })?; + + Ok(true) +} + +/// runs the forward pass operation +#[pyfunction(signature = ( + data, + model, + output, + vk_path=None, + srs_path=None, +))] +fn gen_witness( + data: PathBuf, + model: PathBuf, + output: Option, + vk_path: Option, + srs_path: Option, +) -> PyResult { + let output = Runtime::new() + .unwrap() + .block_on(crate::execute::gen_witness( + model, data, output, vk_path, srs_path, + )) + .map_err(|e| { + let err_str = format!("Failed to run generate witness: {}", e); + PyRuntimeError::new_err(err_str) + })?; + Python::with_gil(|py| Ok(output.to_object(py))) +} + +/// mocks the prover +#[pyfunction(signature = ( + witness, + model, +))] +fn mock(witness: PathBuf, model: PathBuf) -> PyResult { + crate::execute::mock(model, witness).map_err(|e| { + let err_str = format!("Failed to run mock: {}", e); + PyRuntimeError::new_err(err_str) + })?; + Ok(true) +} + +/// mocks the aggregate prover +#[pyfunction(signature = ( + aggregation_snarks, + logrows, + split_proofs = false, +))] +fn mock_aggregate( + aggregation_snarks: Vec, + logrows: u32, + split_proofs: bool, +) -> PyResult { + crate::execute::mock_aggregate(aggregation_snarks, logrows, split_proofs).map_err(|e| { + let err_str = format!("Failed to run mock: {}", e); + PyRuntimeError::new_err(err_str) + })?; + + Ok(true) +} + +/// runs the prover on a set of inputs +#[pyfunction(signature = ( + model, + 
vk_path, + pk_path, + srs_path, + witness_path = None +))] +fn setup( + model: PathBuf, + vk_path: PathBuf, + pk_path: PathBuf, + srs_path: PathBuf, + witness_path: Option, +) -> Result { + crate::execute::setup(model, srs_path, vk_path, pk_path, witness_path).map_err(|e| { + let err_str = format!("Failed to run setup: {}", e); + PyRuntimeError::new_err(err_str) + })?; + + Ok(true) +} + +/// runs the prover on a set of inputs +#[pyfunction(signature = ( + witness, + model, + pk_path, + proof_path, + srs_path, + proof_type, +))] +fn prove( + witness: PathBuf, + model: PathBuf, + pk_path: PathBuf, + proof_path: Option, + srs_path: PathBuf, + proof_type: ProofType, +) -> PyResult { + let snark = crate::execute::prove( + witness, + model, + pk_path, + proof_path, + srs_path, + proof_type, + CheckMode::UNSAFE, + ) + .map_err(|e| { + let err_str = format!("Failed to run prove: {}", e); + PyRuntimeError::new_err(err_str) + })?; + + Python::with_gil(|py| Ok(snark.to_object(py))) +} + +/// verifies a given proof +#[pyfunction(signature = ( + proof_path, + settings_path, + vk_path, + srs_path, +))] +fn verify( + proof_path: PathBuf, + settings_path: PathBuf, + vk_path: PathBuf, + srs_path: PathBuf, +) -> Result { + crate::execute::verify(proof_path, settings_path, vk_path, srs_path).map_err(|e| { + let err_str = format!("Failed to run verify: {}", e); + PyRuntimeError::new_err(err_str) + })?; + + Ok(true) +} + +#[pyfunction(signature = ( + sample_snarks, + vk_path, + pk_path, + srs_path, + logrows, + split_proofs = false, +))] +fn setup_aggregate( + sample_snarks: Vec, + vk_path: PathBuf, + pk_path: PathBuf, + srs_path: PathBuf, + logrows: u32, + split_proofs: bool, +) -> Result { + crate::execute::setup_aggregate( + sample_snarks, + vk_path, + pk_path, + srs_path, + logrows, + split_proofs, + ) + .map_err(|e| { + let err_str = format!("Failed to setup aggregate: {}", e); + PyRuntimeError::new_err(err_str) + })?; + + Ok(true) +} + +#[pyfunction(signature = ( + model, + 
compiled_circuit, + settings_path, +))] +fn compile_circuit( + model: PathBuf, + compiled_circuit: PathBuf, + settings_path: PathBuf, +) -> Result { + crate::execute::compile_circuit(model, compiled_circuit, settings_path).map_err(|e| { + let err_str = format!("Failed to setup aggregate: {}", e); + PyRuntimeError::new_err(err_str) + })?; + + Ok(true) +} + +/// creates an aggregated proof +#[pyfunction(signature = ( + proof_path, + aggregation_snarks, + vk_path, + srs_path, + transcript, + logrows, + check_mode, + split_proofs = false, +))] +fn aggregate( + proof_path: PathBuf, + aggregation_snarks: Vec, + vk_path: PathBuf, + srs_path: PathBuf, + transcript: TranscriptType, + logrows: u32, + check_mode: CheckMode, + split_proofs: bool, +) -> Result { + // the K used for the aggregation circuit + crate::execute::aggregate( + proof_path, + aggregation_snarks, + vk_path, + srs_path, + transcript, + logrows, + check_mode, + split_proofs, + ) + .map_err(|e| { + let err_str = format!("Failed to run aggregate: {}", e); + PyRuntimeError::new_err(err_str) + })?; + + Ok(true) +} + +/// verifies and aggregate proof +#[pyfunction(signature = ( + proof_path, + vk_path, + srs_path, + logrows +))] +fn verify_aggr( + proof_path: PathBuf, + vk_path: PathBuf, + srs_path: PathBuf, + logrows: u32, +) -> Result { + crate::execute::verify_aggr(proof_path, vk_path, srs_path, logrows).map_err(|e| { + let err_str = format!("Failed to run verify_aggr: {}", e); + PyRuntimeError::new_err(err_str) + })?; + + Ok(true) +} + +/// creates an EVM compatible verifier, you will need solc installed in your environment to run this +#[pyfunction(signature = ( + vk_path, + srs_path, + settings_path, + sol_code_path, + abi_path +))] +fn create_evm_verifier( + vk_path: PathBuf, + srs_path: PathBuf, + settings_path: PathBuf, + sol_code_path: PathBuf, + abi_path: PathBuf, +) -> Result { + crate::execute::create_evm_verifier(vk_path, srs_path, settings_path, sol_code_path, abi_path) + .map_err(|e| { + let 
err_str = format!("Failed to run create_evm_verifier: {}", e); + PyRuntimeError::new_err(err_str) + })?; + + Ok(true) +} + +// creates an EVM compatible data attestation verifier, you will need solc installed in your environment to run this +#[pyfunction(signature = ( + vk_path, + srs_path, + settings_path, + sol_code_path, + abi_path, + input_data +))] +fn create_evm_data_attestation( + vk_path: PathBuf, + srs_path: PathBuf, + settings_path: PathBuf, + sol_code_path: PathBuf, + abi_path: PathBuf, + input_data: PathBuf, +) -> Result { + crate::execute::create_evm_data_attestation( + vk_path, + srs_path, + settings_path, + sol_code_path, + abi_path, + input_data, + ) + .map_err(|e| { + let err_str = format!("Failed to run create_evm_data_attestation: {}", e); + PyRuntimeError::new_err(err_str) + })?; + + Ok(true) +} + +#[pyfunction(signature = ( + addr_path, + sol_code_path, + rpc_url=None, + optimizer_runs=1, + private_key=None +))] +fn deploy_evm( + addr_path: PathBuf, + sol_code_path: PathBuf, + rpc_url: Option, + optimizer_runs: usize, + private_key: Option, +) -> Result { + Runtime::new() + .unwrap() + .block_on(crate::execute::deploy_evm( + sol_code_path, + rpc_url, + addr_path, + optimizer_runs, + private_key, + )) + .map_err(|e| { + let err_str = format!("Failed to run deploy_evm: {}", e); + PyRuntimeError::new_err(err_str) + })?; + + Ok(true) +} + +#[pyfunction(signature = ( + addr_path, + input_data, + settings_path, + sol_code_path, + rpc_url=None, + optimizer_runs=1, + private_key=None +))] +fn deploy_da_evm( + addr_path: PathBuf, + input_data: PathBuf, + settings_path: PathBuf, + sol_code_path: PathBuf, + rpc_url: Option, + optimizer_runs: usize, + private_key: Option, +) -> Result { + Runtime::new() + .unwrap() + .block_on(crate::execute::deploy_da_evm( + input_data, + settings_path, + sol_code_path, + rpc_url, + addr_path, + optimizer_runs, + private_key, + )) + .map_err(|e| { + let err_str = format!("Failed to run deploy_da_evm: {}", e); + 
PyRuntimeError::new_err(err_str) + })?; + + Ok(true) +} +/// verifies an evm compatible proof, you will need solc installed in your environment to run this +#[pyfunction(signature = ( + proof_path, + addr_verifier, + rpc_url=None, + addr_da = None, +))] +fn verify_evm( + proof_path: PathBuf, + addr_verifier: &str, + rpc_url: Option, + addr_da: Option<&str>, +) -> Result { + let addr_verifier = H160::from_str(addr_verifier).map_err(|e| { + let err_str = format!("address is invalid: {}", e); + PyRuntimeError::new_err(err_str) + })?; + let addr_da = if let Some(addr_da) = addr_da { + let addr_da = H160::from_str(addr_da).map_err(|e| { + let err_str = format!("address is invalid: {}", e); + PyRuntimeError::new_err(err_str) + })?; + Some(addr_da) + } else { + None + }; + + Runtime::new() + .unwrap() + .block_on(crate::execute::verify_evm( + proof_path, + addr_verifier, + rpc_url, + addr_da, + )) + .map_err(|e| { + let err_str = format!("Failed to run verify_evm: {}", e); + PyRuntimeError::new_err(err_str) + })?; + + Ok(true) +} + +/// creates an evm compatible aggregate verifier, you will need solc installed in your environment to run this +#[pyfunction(signature = ( + vk_path, + srs_path, + sol_code_path, + abi_path, + aggregation_settings +))] +fn create_evm_verifier_aggr( + vk_path: PathBuf, + srs_path: PathBuf, + sol_code_path: PathBuf, + abi_path: PathBuf, + aggregation_settings: Vec, +) -> Result { + crate::execute::create_evm_aggregate_verifier( + vk_path, + srs_path, + sol_code_path, + abi_path, + aggregation_settings, + ) + .map_err(|e| { + let err_str = format!("Failed to run create_evm_verifier_aggr: {}", e); + PyRuntimeError::new_err(err_str) + })?; + Ok(true) +} + +/// print hex representation of a proof +#[pyfunction(signature = (proof_path))] +fn print_proof_hex(proof_path: PathBuf) -> Result { + let proof = Snark::load::>(&proof_path) + .map_err(|_| PyIOError::new_err("Failed to load proof"))?; + + Ok(hex::encode(proof.proof)) +} + +/// deploys a model 
to the hub +#[pyfunction(signature = (model, input, name, organization_id, api_key=None,target=None, py_run_args=None, url=None))] +fn create_hub_artifact( + model: PathBuf, + input: PathBuf, + name: String, + organization_id: String, + api_key: Option<&str>, + target: Option, + py_run_args: Option, + url: Option<&str>, +) -> PyResult { + let run_args: RunArgs = py_run_args.unwrap_or_else(PyRunArgs::new).into(); + let target = target.unwrap_or(CalibrationTarget::Resources { + col_overflow: false, + }); + let output = Runtime::new() + .unwrap() + .block_on(crate::execute::deploy_model( + api_key, + url, + &model, + &input, + &name, + &organization_id, + &run_args, + &target, + )) + .map_err(|e| { + let err_str = format!("Failed to deploy model to hub: {}", e); + PyRuntimeError::new_err(err_str) + })?; + Python::with_gil(|py| Ok(output.to_object(py))) +} + +/// gets a deployed model from the hub +#[pyfunction(signature = (id, api_key=None, url=None))] +fn get_hub_artifact(id: &str, api_key: Option<&str>, url: Option<&str>) -> PyResult { + let output = Runtime::new() + .unwrap() + .block_on(crate::execute::get_deployed_model(api_key, url, &id)) + .map_err(|e| { + let err_str = format!("Failed to get model from hub: {}", e); + PyRuntimeError::new_err(err_str) + })?; + Python::with_gil(|py| Ok(output.to_object(py))) +} + +/// Generate a proof on the hub. 
+#[pyfunction(signature = ( id, input,api_key=None, url=None))] +fn prove_hub( + id: &str, + input: PathBuf, + api_key: Option<&str>, + url: Option<&str>, +) -> PyResult { + let output = Runtime::new() + .unwrap() + .block_on(crate::execute::prove_hub(api_key, url, id, &input)) + .map_err(|e| { + let err_str = format!("Failed to generate proof on hub: {}", e); + PyRuntimeError::new_err(err_str) + })?; + Python::with_gil(|py| Ok(output.to_object(py))) +} + +/// Fetches proof from hub +#[pyfunction(signature = ( id, api_key=None,url=None))] +fn get_hub_proof(id: &str, api_key: Option<&str>, url: Option<&str>) -> PyResult { + let output = Runtime::new() + .unwrap() + .block_on(crate::execute::get_hub_proof(api_key, url, id)) + .map_err(|e| { + let err_str = format!("Failed to get proof from hub: {}", e); + PyRuntimeError::new_err(err_str) + })?; + Python::with_gil(|py| Ok(output.to_object(py))) +} + +/// Gets hub credentials +#[pyfunction(signature = (username,api_key=None, url=None))] +fn get_hub_credentials( + username: &str, + api_key: Option<&str>, + url: Option<&str>, +) -> PyResult { + let output = Runtime::new() + .unwrap() + .block_on(crate::execute::get_hub_credentials(api_key, url, username)) + .map_err(|e| { + let err_str = format!("Failed to get hub credentials: {}", e); + PyRuntimeError::new_err(err_str) + })?; + Python::with_gil(|py| Ok(output.to_object(py))) +} + +// Python Module +#[pymodule] +fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> { + // NOTE: DeployVerifierEVM and SendProofEVM will be implemented in python in pyezkl + pyo3_log::init(); + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_function(wrap_pyfunction!(vecu64_to_felt, m)?)?; + m.add_function(wrap_pyfunction!(vecu64_to_int, m)?)?; + m.add_function(wrap_pyfunction!(vecu64_to_float, m)?)?; + m.add_function(wrap_pyfunction!(kzg_commit, m)?)?; + m.add_function(wrap_pyfunction!(swap_proof_commitments, m)?)?; + 
m.add_function(wrap_pyfunction!(poseidon_hash, m)?)?; + m.add_function(wrap_pyfunction!(elgamal_encrypt, m)?)?; + m.add_function(wrap_pyfunction!(elgamal_decrypt, m)?)?; + m.add_function(wrap_pyfunction!(elgamal_gen_random, m)?)?; + m.add_function(wrap_pyfunction!(float_to_vecu64, m)?)?; + m.add_function(wrap_pyfunction!(buffer_to_felts, m)?)?; + m.add_function(wrap_pyfunction!(gen_vk_from_pk_aggr, m)?)?; + m.add_function(wrap_pyfunction!(gen_vk_from_pk_single, m)?)?; + m.add_function(wrap_pyfunction!(table, m)?)?; + m.add_function(wrap_pyfunction!(mock, m)?)?; + m.add_function(wrap_pyfunction!(setup, m)?)?; + m.add_function(wrap_pyfunction!(prove, m)?)?; + m.add_function(wrap_pyfunction!(verify, m)?)?; + m.add_function(wrap_pyfunction!(gen_srs, m)?)?; + m.add_function(wrap_pyfunction!(get_srs, m)?)?; + m.add_function(wrap_pyfunction!(gen_witness, m)?)?; + m.add_function(wrap_pyfunction!(gen_settings, m)?)?; + m.add_function(wrap_pyfunction!(calibrate_settings, m)?)?; + m.add_function(wrap_pyfunction!(aggregate, m)?)?; + m.add_function(wrap_pyfunction!(mock_aggregate, m)?)?; + m.add_function(wrap_pyfunction!(setup_aggregate, m)?)?; + m.add_function(wrap_pyfunction!(compile_circuit, m)?)?; + m.add_function(wrap_pyfunction!(verify_aggr, m)?)?; + m.add_function(wrap_pyfunction!(create_evm_verifier, m)?)?; + m.add_function(wrap_pyfunction!(deploy_evm, m)?)?; + m.add_function(wrap_pyfunction!(deploy_da_evm, m)?)?; + m.add_function(wrap_pyfunction!(verify_evm, m)?)?; + m.add_function(wrap_pyfunction!(print_proof_hex, m)?)?; + m.add_function(wrap_pyfunction!(create_evm_verifier_aggr, m)?)?; + m.add_function(wrap_pyfunction!(create_evm_data_attestation, m)?)?; + m.add_function(wrap_pyfunction!(create_hub_artifact, m)?)?; + m.add_function(wrap_pyfunction!(get_hub_artifact, m)?)?; + m.add_function(wrap_pyfunction!(prove_hub, m)?)?; + m.add_function(wrap_pyfunction!(get_hub_proof, m)?)?; + m.add_function(wrap_pyfunction!(get_hub_credentials, m)?)?; + + Ok(()) +} diff --git 
a/mnist_ezkl/src/tensor/mod.rs b/mnist_ezkl/src/tensor/mod.rs new file mode 100644 index 0000000..81dc9a6 --- /dev/null +++ b/mnist_ezkl/src/tensor/mod.rs @@ -0,0 +1,1517 @@ +/// Implementations of common operations on tensors. +pub mod ops; +/// A wrapper around a tensor of circuit variables / advices. +pub mod val; +/// A wrapper around a tensor of Halo2 Value types. +pub mod var; + +use halo2curves::ff::PrimeField; +use rayon::{ + prelude::{ + IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, + ParallelIterator, + }, + slice::ParallelSliceMut, +}; +use serde::{Deserialize, Serialize}; +pub use val::*; +pub use var::*; + +use crate::{ + circuit::utils, + fieldutils::{felt_to_i32, i128_to_felt, i32_to_felt}, + graph::Visibility, +}; + +use halo2_proofs::{ + arithmetic::Field, + circuit::{AssignedCell, Region, Value}, + plonk::{Advice, Assigned, Column, ConstraintSystem, Expression, Fixed, VirtualCells}, + poly::Rotation, +}; +use itertools::Itertools; +use std::cmp::max; +use std::error::Error; +use std::fmt::Debug; +use std::iter::Iterator; +use std::ops::{Add, Deref, DerefMut, Div, Mul, Neg, Range, Sub}; +use thiserror::Error; +/// A wrapper for tensor related errors. +#[derive(Debug, Error)] +pub enum TensorError { + /// Shape mismatch in a operation + #[error("dimension mismatch in tensor op: {0}")] + DimMismatch(String), + /// Shape when instantiating + #[error("dimensionality error when manipulating a tensor")] + DimError, + /// wrong method was called on a tensor-like struct + #[error("wrong method called")] + WrongMethod, + /// Significant bit truncation when instantiating + #[error("Significant bit truncation when instantiating, try lowering the scale")] + SigBitTruncationError, + /// Failed to convert to field element tensor + #[error("Failed to convert to field element tensor")] + FeltError, + /// Table lookup error + #[error("Table lookup error")] + TableLookupError, +} + +/// The (inner) type of tensor elements. 
+pub trait TensorType: Clone + Debug + 'static { + /// Returns the zero value. + fn zero() -> Option { + None + } + /// Returns the unit value. + fn one() -> Option { + None + } + /// Max operator for ordering values. + fn tmax(&self, _: &Self) -> Option { + None + } +} + +macro_rules! tensor_type { + ($rust_type:ty, $tensor_type:ident, $zero:expr, $one:expr) => { + impl TensorType for $rust_type { + fn zero() -> Option { + Some($zero) + } + fn one() -> Option { + Some($one) + } + + fn tmax(&self, other: &Self) -> Option { + Some(max(*self, *other)) + } + } + }; +} + +impl TensorType for f32 { + fn zero() -> Option { + Some(0.0) + } + + // f32 doesnt impl Ord so we cant just use max like we can for i32, usize. + // A comparison between f32s needs to handle NAN values. + fn tmax(&self, other: &Self) -> Option { + match (self.is_nan(), other.is_nan()) { + (true, true) => Some(f32::NAN), + (true, false) => Some(*other), + (false, true) => Some(*self), + (false, false) => { + if self >= other { + Some(*self) + } else { + Some(*other) + } + } + } + } +} + +impl TensorType for f64 { + fn zero() -> Option { + Some(0.0) + } + + // f32 doesnt impl Ord so we cant just use max like we can for i32, usize. + // A comparison between f32s needs to handle NAN values. 
+ fn tmax(&self, other: &Self) -> Option { + match (self.is_nan(), other.is_nan()) { + (true, true) => Some(f64::NAN), + (true, false) => Some(*other), + (false, true) => Some(*self), + (false, false) => { + if self >= other { + Some(*self) + } else { + Some(*other) + } + } + } + } +} + +tensor_type!(bool, Bool, false, true); +tensor_type!(i128, Int128, 0, 1); +tensor_type!(i32, Int32, 0, 1); +tensor_type!(usize, USize, 0, 1); +tensor_type!((), Empty, (), ()); +tensor_type!(utils::F32, F32, utils::F32(0.0), utils::F32(1.0)); + +impl TensorType for Tensor { + fn zero() -> Option { + Some(Tensor::new(Some(&[T::zero().unwrap()]), &[1]).unwrap()) + } + fn one() -> Option { + Some(Tensor::new(Some(&[T::one().unwrap()]), &[1]).unwrap()) + } +} + +impl TensorType for Value { + fn zero() -> Option { + Some(Value::known(T::zero().unwrap())) + } + + fn one() -> Option { + Some(Value::known(T::one().unwrap())) + } + + fn tmax(&self, other: &Self) -> Option { + Some( + (self.clone()) + .zip(other.clone()) + .map(|(a, b)| a.tmax(&b).unwrap()), + ) + } +} + +impl TensorType for Assigned +where + F: Field, +{ + fn zero() -> Option { + Some(F::ZERO.into()) + } + + fn one() -> Option { + Some(F::ONE.into()) + } + + fn tmax(&self, other: &Self) -> Option { + if self.evaluate() >= other.evaluate() { + Some(*self) + } else { + Some(*other) + } + } +} + +impl TensorType for Expression +where + F: Field, +{ + fn zero() -> Option { + Some(Expression::Constant(F::ZERO)) + } + + fn one() -> Option { + Some(Expression::Constant(F::ONE)) + } + + fn tmax(&self, _: &Self) -> Option { + todo!() + } +} + +impl TensorType for Column {} +impl TensorType for Column {} + +impl TensorType for AssignedCell, F> { + fn tmax(&self, other: &Self) -> Option { + let mut output: Option = None; + self.value_field().zip(other.value_field()).map(|(a, b)| { + if a.evaluate() >= b.evaluate() { + output = Some(self.clone()); + } else { + output = Some(other.clone()); + } + }); + output + } +} + +impl TensorType 
for AssignedCell { + fn tmax(&self, other: &Self) -> Option { + let mut output: Option = None; + self.value().zip(other.value()).map(|(a, b)| { + if a >= b { + output = Some(self.clone()); + } else { + output = Some(other.clone()); + } + }); + output + } +} + +// specific types +impl TensorType for halo2curves::pasta::Fp { + fn zero() -> Option { + Some(halo2curves::pasta::Fp::zero()) + } + + fn one() -> Option { + Some(halo2curves::pasta::Fp::one()) + } + + fn tmax(&self, other: &Self) -> Option { + Some((*self).max(*other)) + } +} + +impl TensorType for halo2curves::bn256::Fr { + fn zero() -> Option { + Some(halo2curves::bn256::Fr::zero()) + } + + fn one() -> Option { + Some(halo2curves::bn256::Fr::one()) + } + + fn tmax(&self, other: &Self) -> Option { + Some((*self).max(*other)) + } +} + +/// A generic multi-dimensional array representation of a Tensor. +/// The `inner` attribute contains a vector of values whereas `dims` corresponds to the dimensionality of the array +/// and as such determines how we index, query for values, or slice a Tensor. 
+#[derive(Clone, Debug, Eq, Serialize, Deserialize, PartialOrd, Ord)] +pub struct Tensor { + inner: Vec, + dims: Vec, + scale: Option, + visibility: Option, +} + +impl IntoIterator for Tensor { + type Item = T; + type IntoIter = ::std::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.inner.into_iter() + } +} + +impl Deref for Tensor { + type Target = [T]; + + #[inline] + fn deref(&self) -> &[T] { + self.inner.deref() + } +} + +impl DerefMut for Tensor { + #[inline] + fn deref_mut(&mut self) -> &mut [T] { + self.inner.deref_mut() + } +} + +impl PartialEq for Tensor { + fn eq(&self, other: &Tensor) -> bool { + self.dims == other.dims && self.deref() == other.deref() + } +} + +impl From for Tensor +where + I::Item: TensorType + Clone, + Vec: FromIterator, +{ + fn from(value: I) -> Tensor { + let data: Vec = value.collect::>(); + Tensor::new(Some(&data), &[data.len()]).unwrap() + } +} + +impl FromIterator for Tensor +where + T: TensorType + Clone, + Vec: FromIterator, +{ + fn from_iter>(value: I) -> Tensor { + let data: Vec = value.into_iter().collect::>(); + Tensor::new(Some(&data), &[data.len()]).unwrap() + } +} + +impl From, F>>> + for Tensor +{ + fn from(value: Tensor, F>>) -> Tensor { + let mut output = Vec::new(); + value.map(|x| { + x.evaluate().value().map(|y| { + let e = felt_to_i32(*y); + output.push(e); + e + }) + }); + Tensor::new(Some(&output), value.dims()).unwrap() + } +} + +impl From>> + for Tensor +{ + fn from(value: Tensor>) -> Tensor { + let mut output = Vec::new(); + value.map(|x| { + let mut i = 0; + x.value().map(|y| { + let e = felt_to_i32(*y); + output.push(e); + i += 1; + }); + if i == 0 { + output.push(0); + } + }); + Tensor::new(Some(&output), value.dims()).unwrap() + } +} + +impl From, F>>> + for Tensor> +{ + fn from(value: Tensor, F>>) -> Tensor> { + let mut output = Vec::new(); + for (_, x) in value.iter().enumerate() { + output.push(x.value_field().evaluate()); + } + Tensor::new(Some(&output), value.dims()).unwrap() + } +} 
+ +impl From>> for Tensor { + fn from(t: Tensor>) -> Tensor { + let mut output = Vec::new(); + t.map(|x| { + let mut i = 0; + x.map(|y| { + let e = felt_to_i32(y); + output.push(e); + i += 1; + }); + if i == 0 { + output.push(0); + } + }); + Tensor::new(Some(&output), t.dims()).unwrap() + } +} + +impl From>> + for Tensor>> +{ + fn from(t: Tensor>) -> Tensor>> { + let mut ta: Tensor>> = Tensor::from((0..t.len()).map(|i| t[i].into())); + // safe to unwrap as we know the dims are correct + ta.reshape(t.dims()).unwrap(); + ta + } +} + +impl From> for Tensor> { + fn from(t: Tensor) -> Tensor> { + let mut ta: Tensor> = + Tensor::from((0..t.len()).map(|i| Value::known(i32_to_felt::(t[i])))); + // safe to unwrap as we know the dims are correct + ta.reshape(t.dims()).unwrap(); + ta + } +} + +impl From> for Tensor> { + fn from(t: Tensor) -> Tensor> { + let mut ta: Tensor> = + Tensor::from((0..t.len()).map(|i| Value::known(i128_to_felt::(t[i])))); + // safe to unwrap as we know the dims are correct + ta.reshape(t.dims()).unwrap(); + ta + } +} + +impl + rayon::iter::IntoParallelIterator for Tensor +{ + type Iter = rayon::vec::IntoIter; + type Item = T; + fn into_par_iter(self) -> Self::Iter { + self.inner.into_par_iter() + } +} + +impl<'data, T: Clone + TensorType + std::marker::Send + std::marker::Sync> + rayon::iter::IntoParallelRefMutIterator<'data> for Tensor +{ + type Iter = rayon::slice::IterMut<'data, T>; + type Item = &'data mut T; + fn par_iter_mut(&'data mut self) -> Self::Iter { + self.inner.par_iter_mut() + } +} + +impl Tensor { + /// Sets (copies) the tensor values to the provided ones. 
+ pub fn new(values: Option<&[T]>, dims: &[usize]) -> Result { + let total_dims: usize = if !dims.is_empty() { + dims.iter().product() + } else if values.is_some() { + 1 + } else { + 0 + }; + match values { + Some(v) => { + if total_dims != v.len() { + return Err(TensorError::DimError); + } + Ok(Tensor { + inner: Vec::from(v), + dims: Vec::from(dims), + scale: None, + visibility: None, + }) + } + None => Ok(Tensor { + inner: vec![T::zero().unwrap(); total_dims], + dims: Vec::from(dims), + scale: None, + visibility: None, + }), + } + } + + /// set the tensor's (optional) scale parameter + pub fn set_scale(&mut self, scale: crate::Scale) { + self.scale = Some(scale) + } + + /// set the tensor's (optional) visibility parameter + pub fn set_visibility(&mut self, visibility: &Visibility) { + self.visibility = Some(visibility.clone()) + } + + /// getter for scale + pub fn scale(&self) -> Option { + self.scale + } + + /// getter for visibility + pub fn visibility(&self) -> Option { + self.visibility.clone() + } + + /// Returns the number of elements in the tensor. + pub fn len(&self) -> usize { + self.dims().iter().product::() + } + /// Checks if the number of elements in tensor is 0. + pub fn is_empty(&self) -> bool { + self.inner.len() == 0 + } + + /// Checks if the number of elements in tensor is 1 but with an empty dimension (this is for onnx compatibility). + pub fn is_singleton(&self) -> bool { + self.dims().is_empty() && self.len() == 1 + } + + /// Set one single value on the tensor. + /// + /// ``` + /// use ezkl::tensor::Tensor; + /// let mut a = Tensor::::new(None, &[3, 3, 3]).unwrap(); + /// + /// a.set(&[0, 0, 1], 10); + /// assert_eq!(a[0 + 0 + 1], 10); + /// + /// a.set(&[2, 2, 0], 9); + /// assert_eq!(a[2*9 + 2*3 + 0], 9); + /// ``` + pub fn set(&mut self, indices: &[usize], value: T) { + let index = self.get_index(indices); + self[index] = value; + } + + /// Get a single value from the Tensor. 
+ /// + /// ``` + /// use ezkl::tensor::Tensor; + /// let mut a = Tensor::::new(None, &[2, 3, 5]).unwrap(); + /// + /// a[1*15 + 1*5 + 1] = 5; + /// assert_eq!(a.get(&[1, 1, 1]), 5); + /// ``` + pub fn get(&self, indices: &[usize]) -> T { + let index = self.get_index(indices); + self[index].clone() + } + + /// Get a mutable array index from rows / columns indices. + /// + /// ``` + /// use ezkl::tensor::Tensor; + /// let mut a = Tensor::::new(None, &[2, 3, 5]).unwrap(); + /// + /// a[1*15 + 1*5 + 1] = 5; + /// assert_eq!(a.get(&[1, 1, 1]), 5); + /// ``` + pub fn get_mut(&mut self, indices: &[usize]) -> &mut T { + assert_eq!(self.dims.len(), indices.len()); + let mut index = 0; + let mut d = 1; + for i in (0..indices.len()).rev() { + assert!(self.dims[i] > indices[i]); + index += indices[i] * d; + d *= self.dims[i]; + } + &mut self[index] + } + + /// Pad to a length that is divisible by n + /// ``` + /// use ezkl::tensor::Tensor; + /// let mut a = Tensor::::new(Some(&[1,2,3,4,5,6]), &[2, 3]).unwrap(); + /// let expected = Tensor::::new(Some(&[1, 2, 3, 4, 5, 6, 0, 0]), &[8]).unwrap(); + /// assert_eq!(a.pad_to_zero_rem(4).unwrap(), expected); + /// + /// let expected = Tensor::::new(Some(&[1, 2, 3, 4, 5, 6, 0, 0, 0]), &[9]).unwrap(); + /// assert_eq!(a.pad_to_zero_rem(9).unwrap(), expected); + /// ``` + pub fn pad_to_zero_rem(&self, n: usize) -> Result, TensorError> { + let mut inner = self.inner.clone(); + let remainder = self.len() % n; + if remainder != 0 { + inner.resize(self.len() + n - remainder, T::zero().unwrap()); + } + Tensor::new(Some(&inner), &[inner.len()]) + } + + /// Get a single value from the Tensor. 
+ /// + /// ``` + /// use ezkl::tensor::Tensor; + /// let mut a = Tensor::::new(None, &[2, 3, 5]).unwrap(); + /// + /// let flat_index = 1*15 + 1*5 + 1; + /// a[1*15 + 1*5 + 1] = 5; + /// assert_eq!(a.get_flat_index(flat_index), 5); + /// ``` + pub fn get_flat_index(&self, index: usize) -> T { + self[index].clone() + } + + /// Display a tensor + pub fn show(&self) -> String { + if self.len() > 12 { + let start = self[..12].to_vec(); + // print the two split by ... in the middle + format!( + "[{} ...]", + start.iter().map(|x| format!("{:?}", x)).join(", "), + ) + } else { + format!("[{:?}]", self.iter().map(|x| format!("{:?}", x)).join(", ")) + } + } + + /// Get a slice from the Tensor. + /// ``` + /// use ezkl::tensor::Tensor; + /// let mut a = Tensor::::new(Some(&[1, 2, 3]), &[3]).unwrap(); + /// let mut b = Tensor::::new(Some(&[1, 2]), &[2]).unwrap(); + /// + /// assert_eq!(a.get_slice(&[0..2]).unwrap(), b); + /// ``` + pub fn get_slice(&self, indices: &[Range]) -> Result, TensorError> + where + T: Send + Sync, + { + if indices.is_empty() { + return Ok(self.clone()); + } + if self.dims.len() < indices.len() { + return Err(TensorError::DimError); + } else if indices.iter().map(|x| x.end - x.start).collect::>() == self.dims { + // else if slice is the same as dims, return self + return Ok(self.clone()); + } + + // if indices weren't specified we fill them in as required + let mut full_indices = indices.to_vec(); + + for i in 0..(self.dims.len() - indices.len()) { + full_indices.push(0..self.dims()[indices.len() + i]) + } + + let cartesian_coord: Vec> = full_indices + .iter() + .cloned() + .multi_cartesian_product() + .collect(); + + let res: Vec = cartesian_coord + .par_iter() + .map(|e| { + let index = self.get_index(e); + self[index].clone() + }) + .collect(); + + let dims: Vec = full_indices.iter().map(|e| e.end - e.start).collect(); + + Tensor::new(Some(&res), &dims) + } + + /// Get the array index from rows / columns indices. 
+ /// + /// ``` + /// use ezkl::tensor::Tensor; + /// let a = Tensor::::new(None, &[3, 3, 3]).unwrap(); + /// + /// assert_eq!(a.get_index(&[2, 2, 2]), 26); + /// assert_eq!(a.get_index(&[1, 2, 2]), 17); + /// assert_eq!(a.get_index(&[1, 2, 0]), 15); + /// assert_eq!(a.get_index(&[1, 0, 1]), 10); + /// ``` + pub fn get_index(&self, indices: &[usize]) -> usize { + assert_eq!(self.dims.len(), indices.len()); + let mut index = 0; + let mut d = 1; + for i in (0..indices.len()).rev() { + assert!(self.dims[i] > indices[i]); + index += indices[i] * d; + d *= self.dims[i]; + } + index + } + + /// Duplicates every nth element + /// + /// ``` + /// use ezkl::tensor::Tensor; + /// let a = Tensor::::new(Some(&[1, 2, 3, 4, 5, 6]), &[6]).unwrap(); + /// let expected = Tensor::::new(Some(&[1, 2, 3, 3, 4, 5, 5, 6]), &[8]).unwrap(); + /// assert_eq!(a.duplicate_every_n(3, 1, 0).unwrap(), expected); + /// assert_eq!(a.duplicate_every_n(7, 1, 0).unwrap(), a); + /// + /// let expected = Tensor::::new(Some(&[1, 1, 2, 3, 3, 4, 5, 5, 6]), &[9]).unwrap(); + /// assert_eq!(a.duplicate_every_n(3, 1, 2).unwrap(), expected); + /// + /// ``` + pub fn duplicate_every_n( + &self, + n: usize, + num_repeats: usize, + initial_offset: usize, + ) -> Result, TensorError> { + let mut inner: Vec = vec![]; + let mut offset = initial_offset; + for (i, elem) in self.inner.clone().into_iter().enumerate() { + if (i + offset + 1) % n == 0 { + inner.extend(vec![elem; 1 + num_repeats]); + offset += num_repeats; + } else { + inner.push(elem.clone()); + } + } + Tensor::new(Some(&inner), &[inner.len()]) + } + + /// Removes every nth element + /// + /// ``` + /// use ezkl::tensor::Tensor; + /// let a = Tensor::::new(Some(&[1, 2, 3, 3, 4, 5, 6, 6]), &[8]).unwrap(); + /// let expected = Tensor::::new(Some(&[1, 2, 3, 3, 5, 6, 6]), &[7]).unwrap(); + /// assert_eq!(a.remove_every_n(4, 1, 0).unwrap(), expected); + /// + /// + pub fn remove_every_n( + &self, + n: usize, + num_repeats: usize, + initial_offset: usize, + ) 
-> Result, TensorError> { + let mut inner: Vec = vec![]; + let mut indices_to_remove = std::collections::HashSet::new(); + for i in 0..self.inner.len() { + if (i + initial_offset + 1) % n == 0 { + for j in 1..(1 + num_repeats) { + indices_to_remove.insert(i + j); + } + } + } + + let old_inner = self.inner.clone(); + for (i, elem) in old_inner.into_iter().enumerate() { + if !indices_to_remove.contains(&i) { + inner.push(elem.clone()); + } + } + + Tensor::new(Some(&inner), &[inner.len()]) + } + + /// Remove indices + /// WARN: assumes indices are in ascending order for speed + /// ``` + /// use ezkl::tensor::Tensor; + /// let a = Tensor::::new(Some(&[1, 2, 3, 4, 5, 6]), &[6]).unwrap(); + /// let expected = Tensor::::new(Some(&[1, 2, 3, 6]), &[4]).unwrap(); + /// let mut indices = vec![3, 4]; + /// assert_eq!(a.remove_indices(&mut indices, true).unwrap(), expected); + /// + /// + /// let a = Tensor::::new(Some(&[52, -245, 153, 13, -4, -56, -163, 249, -128, -172, 396, 143, 2, -96, 504, -44, -158, -393, 61, 95, 191, 74, 64, -219, 553, 104, 235, 222, 44, -216, 63, -251, 40, -140, 112, -355, 60, 123, 26, -116, -89, -200, -109, 168, 135, -34, -99, -54, 5, -81, 322, 87, 4, -139, 420, 92, -295, -12, 262, -1, 26, -48, 231, 1, -335, 244, 188, -4, 5, -362, 57, -198, -184, -117, 40, 305, 49, 30, -59, -26, -37, 96]), &[82]).unwrap(); + /// let b = Tensor::::new(Some(&[52, -245, 153, 13, -4, -56, -163, 249, -128, -172, 396, 143, 2, -96, 504, -44, -158, -393, 61, 95, 191, 74, 64, -219, 553, 104, 235, 222, 44, -216, 63, -251, 40, -140, 112, -355, 60, 123, 26, -116, -89, -200, -109, 168, 135, -34, -99, -54, 5, -81, 322, 87, 4, -139, 420, 92, -295, -12, 262, -1, 26, -48, 231, -335, 244, 188, 5, -362, 57, -198, -184, -117, 40, 305, 49, 30, -59, -26, -37, 96]), &[80]).unwrap(); + /// let mut indices = vec![63, 67]; + /// assert_eq!(a.remove_indices(&mut indices, true).unwrap(), b); + /// ``` + pub fn remove_indices( + &self, + indices: &mut [usize], + is_sorted: bool, + ) -> Result, 
TensorError> { + let mut inner: Vec = self.inner.clone(); + // time it + if !is_sorted { + indices.par_sort_unstable(); + } + // remove indices + for elem in indices.iter().rev() { + inner.remove(*elem); + } + + Tensor::new(Some(&inner), &[inner.len()]) + } + + /// Returns the tensor's dimensions. + pub fn dims(&self) -> &[usize] { + &self.dims + } + + ///Reshape the tensor + /// ``` + /// use ezkl::tensor::Tensor; + /// let mut a = Tensor::::new(None, &[3, 3, 3]).unwrap(); + /// a.reshape(&[9, 3]); + /// assert_eq!(a.dims(), &[9, 3]); + /// ``` + pub fn reshape(&mut self, new_dims: &[usize]) -> Result<(), TensorError> { + // in onnx parlance this corresponds to converting a tensor to a single element + if new_dims.is_empty() { + if !(self.len() == 1 || self.is_empty()) { + return Err(TensorError::DimError); + } + self.dims = vec![]; + } else { + let product = if new_dims != [0] { + new_dims.iter().product::() + } else { + 0 + }; + if self.len() != product { + return Err(TensorError::DimError); + } + self.dims = Vec::from(new_dims); + } + Ok(()) + } + + /// Move axis of the tensor + /// ``` + /// use ezkl::tensor::Tensor; + /// let mut a = Tensor::::new(None, &[3, 3, 3]).unwrap(); + /// let b = a.move_axis(0, 2).unwrap(); + /// assert_eq!(b.dims(), &[3, 3, 3]); + /// + /// let mut a = Tensor::::new(Some(&[1, 2, 3, 4, 5, 6]), &[3, 1, 2]).unwrap(); + /// let mut expected = Tensor::::new(Some(&[1, 3, 5, 2, 4, 6]), &[1, 2, 3]).unwrap(); + /// let b = a.move_axis(0, 2).unwrap(); + /// assert_eq!(b, expected); + /// + /// let mut a = Tensor::::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), &[2, 3, 2]).unwrap(); + /// let mut expected = Tensor::::new(Some(&[1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12]), &[2, 2, 3]).unwrap(); + /// let b = a.move_axis(1, 2).unwrap(); + /// assert_eq!(b, expected); + /// + /// let mut a = Tensor::::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), &[2, 3, 2]).unwrap(); + /// let mut expected = Tensor::::new(Some(&[1, 3, 5, 2, 4, 6, 7, 9, 
11, 8, 10, 12]), &[2, 2, 3]).unwrap(); + /// let b = a.move_axis(2, 1).unwrap(); + /// assert_eq!(b, expected); + /// ``` + pub fn move_axis(&mut self, source: usize, destination: usize) -> Result { + assert!(source < self.dims.len()); + assert!(destination < self.dims.len()); + let mut new_dims = self.dims.clone(); + new_dims.remove(source); + new_dims.insert(destination, self.dims[source]); + + // now reconfigure the elements appropriately in the new array + // eg. if we have a 3x3x3 array and we want to move the 0th axis to the 2nd position + // we need to move the elements at 0, 1, 2, 3, 4, 5, 6, 7, 8 to 0, 3, 6, 1, 4, 7, 2, 5, 8 + // so we need to move the elements at 0, 1, 2 to 0, 3, 6 + // and the elements at 3, 4, 5 to 1, 4, 7 + // and the elements at 6, 7, 8 to 2, 5, 8 + let cartesian_coords = new_dims + .iter() + .map(|d| 0..*d) + .multi_cartesian_product() + .collect::>>(); + + let mut output = Tensor::new(None, &new_dims)?; + + for coord in cartesian_coords { + let mut old_coord = vec![0; self.dims.len()]; + + // now fetch the old index + for (i, c) in coord.iter().enumerate() { + if i == destination { + old_coord[source] = *c; + } else if i == source && source < destination { + old_coord[source + 1] = *c; + } else if i == source && source > destination { + old_coord[source - 1] = *c; + } else if (i < source && source < destination) + || (i < destination && source > destination) + { + old_coord[i] = *c; + } else if i > source && source < destination { + old_coord[i + 1] = *c; + } else if i > destination && source > destination { + old_coord[i - 1] = *c; + } else { + return Err(TensorError::DimError); + } + } + output.set(&coord, self.get(&old_coord)); + } + + Ok(output) + } + + /// Swap axes of the tensor + /// ``` + /// use ezkl::tensor::Tensor; + /// let mut a = Tensor::::new(None, &[3, 3, 3]).unwrap(); + /// let b = a.swap_axes(0, 2).unwrap(); + /// assert_eq!(b.dims(), &[3, 3, 3]); + /// + /// let mut a = Tensor::::new(Some(&[1, 2, 3, 4, 5, 6]), 
&[3, 1, 2]).unwrap(); + /// let mut expected = Tensor::::new(Some(&[1, 3, 5, 2, 4, 6]), &[2, 1, 3]).unwrap(); + /// let b = a.swap_axes(0, 2).unwrap(); + /// assert_eq!(b, expected); + /// + /// let mut a = Tensor::::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), &[2, 3, 2]).unwrap(); + /// let mut expected = Tensor::::new(Some(&[1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12]), &[2, 2, 3]).unwrap(); + /// let b = a.swap_axes(1, 2).unwrap(); + /// assert_eq!(b, expected); + /// + /// let mut a = Tensor::::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), &[2, 3, 2]).unwrap(); + /// let mut expected = Tensor::::new(Some(&[1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12]), &[2, 2, 3]).unwrap(); + /// let b = a.swap_axes(2, 1).unwrap(); + /// assert_eq!(b, expected); + /// ``` + pub fn swap_axes(&mut self, source: usize, destination: usize) -> Result { + assert!(source < self.dims.len()); + assert!(destination < self.dims.len()); + let mut new_dims = self.dims.clone(); + new_dims[source] = self.dims[destination]; + new_dims[destination] = self.dims[source]; + + // now reconfigure the elements appropriately in the new array + // eg. 
if we have a 3x3x3 array and we want to move the 0th axis to the 2nd position + // we need to move the elements at 0, 1, 2, 3, 4, 5, 6, 7, 8 to 0, 3, 6, 1, 4, 7, 2, 5, 8 + // so we need to move the elements at 0, 1, 2 to 0, 3, 6 + // and the elements at 3, 4, 5 to 1, 4, 7 + // and the elements at 6, 7, 8 to 2, 5, 8 + let cartesian_coords = new_dims + .iter() + .map(|d| 0..*d) + .multi_cartesian_product() + .collect::>>(); + + let mut output = Tensor::new(None, &new_dims)?; + + for coord in cartesian_coords { + let mut old_coord = vec![0; self.dims.len()]; + + // now fetch the old index + for (i, c) in coord.iter().enumerate() { + if i == destination { + old_coord[source] = *c; + } else if i == source { + old_coord[destination] = *c; + } else { + old_coord[i] = *c; + } + } + output.set(&coord, self.get(&old_coord)); + } + + Ok(output) + } + + /// Broadcasts the tensor to a given shape + /// ``` + /// use ezkl::tensor::Tensor; + /// let mut a = Tensor::::new(Some(&[1, 2, 3]), &[3, 1]).unwrap(); + /// + /// let mut expected = Tensor::::new(Some(&[1, 1, 1, 2, 2, 2, 3, 3, 3]), &[3, 3]).unwrap(); + /// assert_eq!(a.expand(&[3, 3]).unwrap(), expected); + /// + /// ``` + pub fn expand(&self, shape: &[usize]) -> Result { + if !(self.dims().len() <= shape.len()) { + return Err(TensorError::DimError); + } + + if shape == self.dims() { + return Ok(self.clone()); + } + + for d in self.dims() { + if !(shape.contains(d) || *d == 1) { + return Err(TensorError::DimError); + } + } + + let cartesian_coords = shape + .iter() + .map(|d| 0..*d) + .multi_cartesian_product() + .collect::>>(); + + let mut output = Tensor::new(None, shape)?; + + for coord in cartesian_coords { + let mut new_coord = Vec::with_capacity(self.dims().len()); + for (i, c) in coord.iter().enumerate() { + if i < self.dims().len() && self.dims()[i] == 1 { + new_coord.push(0); + } else if i >= self.dims().len() { + // do nothing at this point does not exist in the original tensor + } else { + new_coord.push(*c); + } 
+ } + output.set(&coord, self.get(&new_coord)); + } + + Ok(output) + } + + ///Flatten the tensor shape + /// ``` + /// use ezkl::tensor::Tensor; + /// let mut a = Tensor::::new(None, &[3, 3, 3]).unwrap(); + /// a.flatten(); + /// assert_eq!(a.dims(), &[27]); + /// ``` + pub fn flatten(&mut self) { + if !self.dims().is_empty() && (self.dims() != [0]) { + self.dims = Vec::from([self.dims.iter().product::()]); + } + } + + /// Maps a function to tensors + /// ``` + /// use ezkl::tensor::Tensor; + /// let mut a = Tensor::::new(Some(&[1, 4]), &[2]).unwrap(); + /// let mut c = a.map(|x| i32::pow(x,2)); + /// assert_eq!(c, Tensor::from([1, 16].into_iter())) + /// ``` + pub fn map G, G: TensorType>(&self, mut f: F) -> Tensor { + let mut t = Tensor::from(self.inner.iter().map(|e| f(e.clone()))); + // safe to unwrap as we know the dims are correct + t.reshape(self.dims()).unwrap(); + t + } + + /// Maps a function to tensors and enumerates + /// ``` + /// use ezkl::tensor::{Tensor, TensorError}; + /// let mut a = Tensor::::new(Some(&[1, 4]), &[2]).unwrap(); + /// let mut c = a.enum_map::<_,_,TensorError>(|i, x| Ok(i32::pow(x + i as i32, 2))).unwrap(); + /// assert_eq!(c, Tensor::from([1, 25].into_iter())); + /// ``` + pub fn enum_map Result, G: TensorType, E: Error>( + &self, + mut f: F, + ) -> Result, E> { + let vec: Result, E> = self + .inner + .iter() + .enumerate() + .map(|(i, e)| f(i, e.clone())) + .collect(); + let mut t: Tensor = Tensor::from(vec?.iter().cloned()); + // safe to unwrap as we know the dims are correct + t.reshape(self.dims()).unwrap(); + Ok(t) + } + + /// Maps a function to tensors and enumerates in parallel + /// ``` + /// use ezkl::tensor::{Tensor, TensorError}; + /// let mut a = Tensor::::new(Some(&[1, 4]), &[2]).unwrap(); + /// let mut c = a.par_enum_map::<_,_,TensorError>(|i, x| Ok(i32::pow(x + i as i32, 2))).unwrap(); + /// assert_eq!(c, Tensor::from([1, 25].into_iter())); + /// ``` + pub fn par_enum_map< + F: Fn(usize, T) -> Result + 
std::marker::Send + std::marker::Sync, + G: TensorType + std::marker::Send + std::marker::Sync, + E: Error + std::marker::Send + std::marker::Sync, + >( + &self, + f: F, + ) -> Result, E> + where + T: std::marker::Send + std::marker::Sync, + { + let vec: Result, E> = self + .inner + .par_iter() + .enumerate() + .map(move |(i, e)| f(i, e.clone())) + .collect(); + let mut t: Tensor = Tensor::from(vec?.iter().cloned()); + // safe to unwrap as we know the dims are correct + t.reshape(self.dims()).unwrap(); + Ok(t) + } + + /// Maps a function to tensors and enumerates in parallel + /// ``` + /// use ezkl::tensor::{Tensor, TensorError}; + /// let mut a = Tensor::::new(Some(&[1, 4]), &[2]).unwrap(); + /// let mut c = a.par_enum_map::<_,_,TensorError>(|i, x| Ok(i32::pow(x + i as i32, 2))).unwrap(); + /// assert_eq!(c, Tensor::from([1, 25].into_iter())); + /// ``` + pub fn par_enum_map_mut_filtered< + F: Fn(usize) -> Result + std::marker::Send + std::marker::Sync, + E: Error + std::marker::Send + std::marker::Sync, + >( + &mut self, + filter_indices: &std::collections::HashSet<&usize>, + f: F, + ) -> Result<(), E> + where + T: std::marker::Send + std::marker::Sync, + { + self.inner + .par_iter_mut() + .enumerate() + .filter(|(i, _)| filter_indices.contains(i)) + .for_each(move |(i, e)| *e = f(i).unwrap()); + Ok(()) + } +} + +impl Tensor> { + /// Flattens a tensor of tensors + /// ``` + /// use ezkl::tensor::Tensor; + /// let mut a = Tensor::::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap(); + /// let mut b = Tensor::::new(Some(&[1, 4]), &[2, 1]).unwrap(); + /// let mut c = Tensor::new(Some(&[a,b]), &[2]).unwrap(); + /// let mut d = c.combine().unwrap(); + /// assert_eq!(d.dims(), &[8]); + /// ``` + pub fn combine(&self) -> Result, TensorError> { + let mut dims = 0; + let mut inner = Vec::new(); + for t in self.inner.clone().into_iter() { + dims += t.len(); + inner.extend(t.inner); + } + Tensor::new(Some(&inner), &[dims]) + } +} + +impl + std::marker::Send + 
std::marker::Sync> Add for Tensor { + type Output = Result, TensorError>; + /// Adds tensors. + /// # Arguments + /// + /// * `self` - Tensor + /// * `rhs` - Tensor + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use std::ops::Add; + /// let x = Tensor::::new( + /// Some(&[2, 1, 2, 1, 1, 1]), + /// &[2, 3], + /// ).unwrap(); + /// let k = Tensor::::new( + /// Some(&[2, 3, 2, 1, 1, 1]), + /// &[2, 3], + /// ).unwrap(); + /// let result = x.add(k).unwrap(); + /// let expected = Tensor::::new(Some(&[4, 4, 4, 2, 2, 2]), &[2, 3]).unwrap(); + /// assert_eq!(result, expected); + /// + /// // Now test 1D casting + /// let x = Tensor::::new( + /// Some(&[2, 1, 2, 1, 1, 1]), + /// &[2, 3], + /// ).unwrap(); + /// let k = Tensor::::new( + /// Some(&[2]), + /// &[1]).unwrap(); + /// let result = x.add(k).unwrap(); + /// let expected = Tensor::::new(Some(&[4, 3, 4, 3, 3, 3]), &[2, 3]).unwrap(); + /// assert_eq!(result, expected); + /// + /// + /// // Now test 2D casting + /// let x = Tensor::::new( + /// Some(&[2, 1, 2, 1, 1, 1]), + /// &[2, 3], + /// ).unwrap(); + /// let k = Tensor::::new( + /// Some(&[2, 3]), + /// &[2]).unwrap(); + /// let result = x.add(k).unwrap(); + /// let expected = Tensor::::new(Some(&[4, 3, 4, 4, 4, 4]), &[2, 3]).unwrap(); + /// assert_eq!(result, expected); + /// ``` + fn add(self, rhs: Self) -> Self::Output { + let broadcasted_shape = get_broadcasted_shape(self.dims(), rhs.dims()).unwrap(); + let mut lhs = self.expand(&broadcasted_shape).unwrap(); + let rhs = rhs.expand(&broadcasted_shape).unwrap(); + + lhs.par_iter_mut().zip(rhs).for_each(|(o, r)| { + *o = o.clone() + r; + }); + + Ok(lhs) + } +} + +impl + std::marker::Send + std::marker::Sync> Neg for Tensor { + type Output = Tensor; + /// Negates a tensor. 
+ /// # Arguments + /// * `self` - Tensor + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use std::ops::Neg; + /// let x = Tensor::::new( + /// Some(&[2, 1, 2, 1, 1, 1]), + /// &[2, 3], + /// ).unwrap(); + /// let result = x.neg(); + /// let expected = Tensor::::new(Some(&[-2, -1, -2, -1, -1, -1]), &[2, 3]).unwrap(); + /// assert_eq!(result, expected); + /// ``` + fn neg(self) -> Self { + let mut output = self; + output.par_iter_mut().for_each(|x| { + *x = x.clone().neg(); + }); + output + } +} + +impl + std::marker::Send + std::marker::Sync> Sub for Tensor { + type Output = Result, TensorError>; + /// Subtracts tensors. + /// # Arguments + /// + /// * `self` - Tensor + /// * `rhs` - Tensor + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use std::ops::Sub; + /// let x = Tensor::::new( + /// Some(&[2, 1, 2, 1, 1, 1]), + /// &[2, 3], + /// ).unwrap(); + /// let k = Tensor::::new( + /// Some(&[2, 3, 2, 1, 1, 1]), + /// &[2, 3], + /// ).unwrap(); + /// let result = x.sub(k).unwrap(); + /// let expected = Tensor::::new(Some(&[0, -2, 0, 0, 0, 0]), &[2, 3]).unwrap(); + /// assert_eq!(result, expected); + /// + /// // Now test 1D sub + /// let x = Tensor::::new( + /// Some(&[2, 1, 2, 1, 1, 1]), + /// &[2, 3], + /// ).unwrap(); + /// let k = Tensor::::new( + /// Some(&[2]), + /// &[1], + /// ).unwrap(); + /// let result = x.sub(k).unwrap(); + /// let expected = Tensor::::new(Some(&[0, -1, 0, -1, -1, -1]), &[2, 3]).unwrap(); + /// assert_eq!(result, expected); + /// + /// // Now test 2D sub + /// let x = Tensor::::new( + /// Some(&[2, 1, 2, 1, 1, 1]), + /// &[2, 3], + /// ).unwrap(); + /// let k = Tensor::::new( + /// Some(&[2, 3]), + /// &[2], + /// ).unwrap(); + /// let result = x.sub(k).unwrap(); + /// let expected = Tensor::::new(Some(&[0, -1, 0, -2, -2, -2]), &[2, 3]).unwrap(); + /// assert_eq!(result, expected); + /// ``` + fn sub(self, rhs: Self) -> Self::Output { + let broadcasted_shape = get_broadcasted_shape(self.dims(), 
rhs.dims()).unwrap(); + let mut lhs = self.expand(&broadcasted_shape).unwrap(); + let rhs = rhs.expand(&broadcasted_shape).unwrap(); + + lhs.par_iter_mut().zip(rhs).for_each(|(o, r)| { + *o = o.clone() - r; + }); + + Ok(lhs) + } +} + +impl + std::marker::Send + std::marker::Sync> Mul for Tensor { + type Output = Result, TensorError>; + /// Elementwise multiplies tensors. + /// # Arguments + /// + /// * `self` - Tensor + /// * `rhs` - Tensor + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use std::ops::Mul; + /// let x = Tensor::::new( + /// Some(&[2, 1, 2, 1, 1, 1]), + /// &[2, 3], + /// ).unwrap(); + /// let k = Tensor::::new( + /// Some(&[2, 3, 2, 1, 1, 1]), + /// &[2, 3], + /// ).unwrap(); + /// let result = x.mul(k).unwrap(); + /// let expected = Tensor::::new(Some(&[4, 3, 4, 1, 1, 1]), &[2, 3]).unwrap(); + /// assert_eq!(result, expected); + /// + /// // Now test 1D mult + /// let x = Tensor::::new( + /// Some(&[2, 1, 2, 1, 1, 1]), + /// &[2, 3], + /// ).unwrap(); + /// let k = Tensor::::new( + /// Some(&[2]), + /// &[1]).unwrap(); + /// let result = x.mul(k).unwrap(); + /// let expected = Tensor::::new(Some(&[4, 2, 4, 2, 2, 2]), &[2, 3]).unwrap(); + /// assert_eq!(result, expected); + /// + /// // Now test 2D mult + /// let x = Tensor::::new( + /// Some(&[2, 1, 2, 1, 1, 1]), + /// &[2, 3], + /// ).unwrap(); + /// let k = Tensor::::new( + /// Some(&[2, 2]), + /// &[2]).unwrap(); + /// let result = x.mul(k).unwrap(); + /// let expected = Tensor::::new(Some(&[4, 2, 4, 2, 2, 2]), &[2, 3]).unwrap(); + /// assert_eq!(result, expected); + /// ``` + fn mul(self, rhs: Self) -> Self::Output { + let broadcasted_shape = get_broadcasted_shape(self.dims(), rhs.dims()).unwrap(); + let mut lhs = self.expand(&broadcasted_shape).unwrap(); + let rhs = rhs.expand(&broadcasted_shape).unwrap(); + + lhs.par_iter_mut().zip(rhs).for_each(|(o, r)| { + *o = o.clone() * r; + }); + + Ok(lhs) + } +} + +impl + std::marker::Send + std::marker::Sync> Tensor { + /// 
Elementwise raise a tensor to the nth power. + /// # Arguments + /// + /// * `self` - Tensor + /// * `b` - Single value + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use std::ops::Mul; + /// let x = Tensor::::new( + /// Some(&[2, 15, 2, 1, 1, 0]), + /// &[2, 3], + /// ).unwrap(); + /// let result = x.pow(3).unwrap(); + /// let expected = Tensor::::new(Some(&[8, 3375, 8, 1, 1, 0]), &[2, 3]).unwrap(); + /// assert_eq!(result, expected); + /// ``` + pub fn pow(&self, mut exp: u32) -> Result { + // calculate value of output + let mut base = self.clone(); + let mut acc = base.map(|_| T::one().unwrap()); + + while exp > 1 { + if (exp & 1) == 1 { + acc = acc.mul(base.clone())?; + } + exp /= 2; + base = base.clone().mul(base)?; + } + + // since exp!=0, finally the exp must be 1. + // Deal with the final bit of the exponent separately, since + // squaring the base afterwards is not necessary and may cause a + // needless overflow. + acc.mul(base) + } +} + +impl + std::marker::Send + std::marker::Sync> Div for Tensor { + type Output = Result, TensorError>; + /// Elementwise divide a tensor with another tensor. 
+ /// # Arguments + /// + /// * `self` - Tensor + /// * `rhs` - Tensor + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use std::ops::Div; + /// let x = Tensor::::new( + /// Some(&[4, 1, 4, 1, 1, 4]), + /// &[2, 3], + /// ).unwrap(); + /// let y = Tensor::::new( + /// Some(&[2, 1, 2, 1, 1, 1]), + /// &[2, 3], + /// ).unwrap(); + /// let result = x.div(y).unwrap(); + /// let expected = Tensor::::new(Some(&[2, 1, 2, 1, 1, 4]), &[2, 3]).unwrap(); + /// assert_eq!(result, expected); + /// + /// // test 1D casting + /// let x = Tensor::::new( + /// Some(&[4, 1, 4, 1, 1, 4]), + /// &[2, 3], + /// ).unwrap(); + /// let y = Tensor::::new( + /// Some(&[2]), + /// &[1], + /// ).unwrap(); + /// let result = x.div(y).unwrap(); + /// let expected = Tensor::::new(Some(&[2, 0, 2, 0, 0, 2]), &[2, 3]).unwrap(); + /// assert_eq!(result, expected); + /// ``` + fn div(self, rhs: Self) -> Self::Output { + let broadcasted_shape = get_broadcasted_shape(self.dims(), rhs.dims()).unwrap(); + let mut lhs = self.expand(&broadcasted_shape).unwrap(); + let rhs = rhs.expand(&broadcasted_shape).unwrap(); + + lhs.par_iter_mut().zip(rhs).for_each(|(o, r)| { + *o = o.clone() / r; + }); + + Ok(lhs) + } +} + +/// Returns the broadcasted shape of two tensors +/// ``` +/// use ezkl::tensor::get_broadcasted_shape; +/// let a = vec![2, 3]; +/// let b = vec![2, 3]; +/// let c = get_broadcasted_shape(&a, &b).unwrap(); +/// assert_eq!(c, vec![2, 3]); +/// +/// let a = vec![2, 3]; +/// let b = vec![3]; +/// let c = get_broadcasted_shape(&a, &b).unwrap(); +/// assert_eq!(c, vec![2, 3]); +/// +/// let a = vec![2, 3]; +/// let b = vec![2, 1]; +/// let c = get_broadcasted_shape(&a, &b).unwrap(); +/// assert_eq!(c, vec![2, 3]); +/// +/// let a = vec![2, 3]; +/// let b = vec![1, 3]; +/// let c = get_broadcasted_shape(&a, &b).unwrap(); +/// assert_eq!(c, vec![2, 3]); +/// +/// let a = vec![2, 3]; +/// let b = vec![1, 1]; +/// let c = get_broadcasted_shape(&a, &b).unwrap(); +/// assert_eq!(c, vec![2, 
3]); +/// +/// ``` + +pub fn get_broadcasted_shape( + shape_a: &[usize], + shape_b: &[usize], +) -> Result, Box> { + let num_dims_a = shape_a.len(); + let num_dims_b = shape_b.len(); + + // reewrite the below using match + if num_dims_a == num_dims_b { + let mut broadcasted_shape = Vec::with_capacity(num_dims_a); + for (dim_a, dim_b) in shape_a.iter().zip(shape_b.iter()) { + let max_dim = dim_a.max(dim_b); + broadcasted_shape.push(*max_dim); + } + Ok(broadcasted_shape) + } else if num_dims_a < num_dims_b { + Ok(shape_b.to_vec()) + } else { + Ok(shape_a.to_vec()) + } +} +//////////////////////// + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_tensor() { + let data: Vec = vec![-1.0f32, 0.0, 1.0, 2.5]; + let tensor = Tensor::::new(Some(&data), &[2, 2]).unwrap(); + assert_eq!(&tensor[..], &data[..]); + } + + #[test] + fn tensor_clone() { + let x = Tensor::::new(Some(&[1, 2, 3]), &[3]).unwrap(); + assert_eq!(x, x.clone()); + } + + #[test] + fn tensor_eq() { + let a = Tensor::::new(Some(&[1, 2, 3]), &[3]).unwrap(); + let mut b = Tensor::::new(Some(&[1, 2, 3]), &[3, 1]).unwrap(); + b.reshape(&[3]).unwrap(); + let c = Tensor::::new(Some(&[1, 2, 4]), &[3]).unwrap(); + let d = Tensor::::new(Some(&[1, 2, 4]), &[3, 1]).unwrap(); + assert_eq!(a, b); + assert_ne!(a, c); + assert_ne!(a, d); + } + #[test] + fn tensor_slice() { + let a = Tensor::::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap(); + let b = Tensor::::new(Some(&[1, 4]), &[2, 1]).unwrap(); + assert_eq!(a.get_slice(&[0..2, 0..1]).unwrap(), b); + } +} diff --git a/mnist_ezkl/src/tensor/ops.rs b/mnist_ezkl/src/tensor/ops.rs new file mode 100644 index 0000000..c0514d1 --- /dev/null +++ b/mnist_ezkl/src/tensor/ops.rs @@ -0,0 +1,4008 @@ +use super::TensorError; +use crate::tensor::{Tensor, TensorType}; +use itertools::Itertools; +use rayon::{ + iter::IndexedParallelIterator, iter::IntoParallelRefMutIterator, iter::ParallelIterator, + prelude::IntoParallelRefIterator, +}; +use std::collections::{HashMap, 
HashSet};
pub use std::ops::{Add, Div, Mul, Neg, Sub};

/// IFF operation: elementwise `mask ? a : b`.
/// # Arguments
/// * `mask` - Tensor of 0s and 1s; 1 selects from `a`, 0 selects from `b`
/// * `a` - Tensor
/// * `b` - Tensor
/// # Errors
/// Returns `TensorError::WrongMethod` if `mask` contains a value that is
/// neither 0 nor 1.
pub fn iff<
    T: TensorType
        + Add<Output = T>
        + Mul<Output = T>
        + Sub<Output = T>
        + std::marker::Send
        + std::marker::Sync
        + std::cmp::PartialEq,
>(
    mask: &Tensor<T>,
    a: &Tensor<T>,
    b: &Tensor<T>,
) -> Result<Tensor<T>, TensorError> {
    // reject non-boolean masks up front
    if !mask
        .par_iter()
        .all(|x| *x == T::one().unwrap() || *x == T::zero().unwrap())
    {
        return Err(TensorError::WrongMethod);
    }

    // mask * a keeps entries where mask == 1
    let masked_a = (mask.clone() * a.clone())?;

    // (1 - mask) * b keeps entries where mask == 0
    let masked_b = ((Tensor::from(vec![T::one().ok_or(TensorError::DimError)?].into_iter())
        - mask.clone())?
        * b.clone())?;

    masked_a + masked_b
}

/// Elementwise applies not to a tensor of integers.
+/// # Arguments +/// * `a` - Tensor +/// # Examples +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::not; +/// let x = Tensor::::new( +/// Some(&[1, 1, 1, 1, 1, 0]), +/// &[2, 3], +/// ).unwrap(); +/// let result = not(&x).unwrap(); +/// let expected = Tensor::::new(Some(&[0, 0, 0, 0, 0, 1]), &[2, 3]).unwrap(); +/// assert_eq!(result, expected); +/// ``` +pub fn not< + T: TensorType + + Add + + Mul + + Sub + + std::marker::Send + + std::marker::Sync + + std::cmp::PartialEq, +>( + a: &Tensor, +) -> Result, TensorError> { + iff( + a, + &Tensor::from(vec![T::zero().unwrap()].into_iter()), + &Tensor::from(vec![T::one().unwrap()].into_iter()), + ) +} + +/// Elementwise applies or to two tensors of integers. +/// # Arguments +/// * `a` - Tensor +/// * `b` - Tensor +/// # Examples +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::or; +/// let a = Tensor::::new( +/// Some(&[1, 1, 1, 1, 1, 0]), +/// &[2, 3], +/// ).unwrap(); +/// let b = Tensor::::new( +/// Some(&[1, 0, 1, 0, 1, 0]), +/// &[2, 3], +/// ).unwrap(); +/// let result = or(&a, &b).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 1, 1, 1, 1, 0]), &[2, 3]).unwrap(); +/// assert_eq!(result, expected); +/// ``` +pub fn or< + T: TensorType + + Add + + Mul + + Sub + + std::marker::Send + + std::marker::Sync + + std::cmp::PartialEq, +>( + a: &Tensor, + b: &Tensor, +) -> Result, TensorError> { + if !b + .par_iter() + .all(|x| *x == T::one().unwrap() || *x == T::zero().unwrap()) + { + return Err(TensorError::WrongMethod); + } + + iff(a, a, b) +} + +/// Elementwise applies xor to two tensors +/// # Arguments +/// * `a` - Tensor +/// * `b` - Tensor +/// # Examples +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::xor; +/// let a = Tensor::::new( +/// Some(&[1, 1, 1, 1, 1, 0]), +/// &[2, 3], +/// ).unwrap(); +/// let b = Tensor::::new( +/// Some(&[1, 0, 1, 0, 1, 0]), +/// &[2, 3], +/// ).unwrap(); +/// let result = xor(&a, &b).unwrap(); +/// let expected = 
Tensor::::new(Some(&[0, 1, 0, 1, 0, 0]), &[2, 3]).unwrap(); +/// assert_eq!(result, expected); +/// ``` +/// +pub fn xor< + T: TensorType + + Add + + Mul + + Sub + + std::marker::Send + + std::marker::Sync + + std::cmp::PartialEq, +>( + a: &Tensor, + b: &Tensor, +) -> Result, TensorError> { + let a_not_b = (a.clone() * not(b)?)?; + let b_not_a = (b.clone() * not(a)?)?; + a_not_b + b_not_a +} + +/// Elementwise applies and to two tensors +/// # Arguments +/// * `a` - Tensor +/// * `b` - Tensor +/// # Examples +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::and; +/// let a = Tensor::::new( +/// Some(&[1, 1, 1, 1, 1, 0]), +/// &[2, 3], +/// ).unwrap(); +/// let b = Tensor::::new( +/// Some(&[1, 0, 1, 0, 1, 0]), +/// &[2, 3], +/// ).unwrap(); +/// let result = and(&a, &b).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 0, 1, 0, 1, 0]), &[2, 3]).unwrap(); +/// assert_eq!(result, expected); +/// ``` +pub fn and< + T: TensorType + + Add + + Mul + + Sub + + std::marker::Send + + std::marker::Sync + + std::cmp::PartialEq, +>( + a: &Tensor, + b: &Tensor, +) -> Result, TensorError> { + // assert is boolean + if !b + .par_iter() + .all(|x| *x == T::one().unwrap() || *x == T::zero().unwrap()) + { + return Err(TensorError::WrongMethod); + } + + // assert is boolean + if !a + .par_iter() + .all(|x| *x == T::one().unwrap() || *x == T::zero().unwrap()) + { + return Err(TensorError::WrongMethod); + } + + a.clone() * b.clone() +} + +/// Elementwise applies equals to two tensors of integers. 
+/// # Arguments +/// * `a` - Tensor +/// * `b` - Tensor +/// # Examples +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::equals; +/// let a = Tensor::::new( +/// Some(&[1, 1, 1, 1, 1, 0]), +/// &[2, 3], +/// ).unwrap(); +/// let b = Tensor::::new( +/// Some(&[1, 0, 1, 0, 1, 0]), +/// &[2, 3], +/// ).unwrap(); +/// let result = equals(&a, &b).unwrap().0; +/// let expected = Tensor::::new(Some(&[1, 0, 1, 0, 1, 1]), &[2, 3]).unwrap(); +/// assert_eq!(result, expected); +/// ``` +pub fn equals< + T: TensorType + + std::marker::Send + + std::marker::Sync + + Sub + + Mul + + Add + + std::cmp::PartialEq + + std::cmp::PartialOrd + + std::convert::From, +>( + a: &Tensor, + b: &Tensor, +) -> Result<(Tensor, Vec>), TensorError> { + let a = a.clone(); + let b = b.clone(); + + let diff = (a - b)?; + + let result = nonlinearities::kronecker_delta(&diff); + + Ok((result, vec![diff])) +} + +/// Greater than operation. +/// # Arguments +/// * `a` - Tensor +/// * `b` - Tensor +/// # Examples +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::greater; +/// let a = Tensor::::new( +/// Some(&[1, 12, 6, 4, 5, 6]), +/// &[2, 3], +/// ).unwrap(); +/// let b = Tensor::::new( +/// Some(&[1, 2, 3, 4, 5, 6]), +/// &[2, 3], +/// ).unwrap(); +/// let result = greater(&a, &b).unwrap(); +/// let expected = Tensor::::new(Some(&[0, 1, 1, 0, 0, 0]), &[2, 3]).unwrap(); +/// assert_eq!(result.0, expected); +/// ``` +pub fn greater< + T: TensorType + + Sub + + Mul + + std::marker::Send + + std::marker::Sync + + std::cmp::PartialOrd + + std::convert::TryFrom, +>( + a: &Tensor, + b: &Tensor, +) -> Result<(Tensor, Vec>), TensorError> { + let mask_inter = (a.clone() - b.clone())?; + let mask = mask_inter.map(|x| { + if x > T::zero().ok_or(TensorError::DimError).unwrap() { + T::one().ok_or(TensorError::DimError).unwrap() + } else { + T::zero().ok_or(TensorError::DimError).unwrap() + } + }); + Ok((mask, vec![mask_inter])) +} + +/// Greater equals than operation. 
+/// # Arguments +/// * `a` - Tensor +/// * `b` - Tensor +/// # Examples +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::greater_equal; +/// let a = Tensor::::new( +/// Some(&[1, 12, 6, 4, 3, 2]), +/// &[2, 3], +/// ).unwrap(); +/// let b = Tensor::::new( +/// Some(&[1, 2, 3, 4, 5, 4]), +/// &[2, 3], +/// ).unwrap(); +/// let result = greater_equal(&a, &b).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 1, 1, 1, 0, 0]), &[2, 3]).unwrap(); +/// assert_eq!(result.0, expected); +/// ``` +pub fn greater_equal< + T: TensorType + + Sub + + Mul + + std::marker::Send + + std::marker::Sync + + std::cmp::PartialOrd + + std::convert::TryFrom, +>( + a: &Tensor, + b: &Tensor, +) -> Result<(Tensor, Vec>), TensorError> { + let mask_inter = (a.clone() - b.clone())?; + let mask = mask_inter.map(|x| { + if x >= T::zero().ok_or(TensorError::DimError).unwrap() { + T::one().ok_or(TensorError::DimError).unwrap() + } else { + T::zero().ok_or(TensorError::DimError).unwrap() + } + }); + Ok((mask, vec![mask_inter])) +} + +/// Less than to operation. +/// # Arguments +/// * `a` - Tensor +/// * `b` - Tensor +/// # Examples +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::less; +/// let a = Tensor::::new( +/// Some(&[1, 0, 5, 4, 5, 1]), +/// &[2, 3], +/// ).unwrap(); +/// let b = Tensor::::new( +/// Some(&[1, 2, 3, 4, 5, 6]), +/// &[2, 3], +/// ).unwrap(); +/// let result = less(&a, &b).unwrap(); +/// let expected = Tensor::::new(Some(&[0, 1, 0, 0, 0, 1]), &[2, 3]).unwrap(); +/// assert_eq!(result.0, expected); +/// ``` +/// +pub fn less< + T: TensorType + + Sub + + Mul + + std::marker::Send + + std::marker::Sync + + std::cmp::PartialOrd + + std::convert::TryFrom, +>( + a: &Tensor, + b: &Tensor, +) -> Result<(Tensor, Vec>), TensorError> { + // a < b <=> b > a + greater(b, a) +} + +/// Less equals than operation. 
+/// # Arguments +/// * `a` - Tensor +/// * `b` - Tensor +/// # Examples +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::less_equal; +/// let a = Tensor::::new( +/// Some(&[1, 0, 5, 4, 5, 1]), +/// &[2, 3], +/// ).unwrap(); +/// let b = Tensor::::new( +/// Some(&[1, 2, 3, 4, 5, 6]), +/// &[2, 3], +/// ).unwrap(); +/// let result = less_equal(&a, &b).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 1, 0, 1, 1, 1]), &[2, 3]).unwrap(); +/// assert_eq!(result.0, expected); +/// ``` +/// +pub fn less_equal< + T: TensorType + + Sub + + Mul + + std::marker::Send + + std::marker::Sync + + std::cmp::PartialOrd + + std::convert::TryFrom, +>( + a: &Tensor, + b: &Tensor, +) -> Result<(Tensor, Vec>), TensorError> { + // a < b <=> b > a + greater_equal(b, a) +} + +/// Resize using nearest neighbour interpolation. +/// # Arguments +/// * `a` - Tensor +/// * `scales` - Vector of scales +/// # Examples +/// ``` +/// +/// +/// let a = Tensor::::new( +/// Some(&[1, 2, 3, 4, 5, 6]), +/// &[2, 3], +/// ).unwrap(); +/// let result = resize(&a, &[1, 2]).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6]), &[2, 6]).unwrap(); +/// assert_eq!(result, expected); +/// +/// +/// let a = Tensor::::new( +/// Some(&[1, 2, 3, 4, 5, 6]), +/// &[2, 3], +/// ).unwrap(); +/// let result = resize(&a, &[2, 2]).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 1, 2, 2, 3, 3, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6]), &[4, 6]).unwrap(); +/// assert_eq!(result, expected); +/// +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::resize; +/// let a = Tensor::::new( +/// Some(&[1, 2, 3, 4]), +/// &[2, 2], +/// ).unwrap(); +/// let result = resize(&a, &[2, 2]).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4]), &[4, 4]).unwrap(); +/// assert_eq!(result, expected); +/// +/// +/// let a = Tensor::::new( +/// Some(&[1, 2, 3, 4, 5, 6]), +/// &[3, 2], +/// ).unwrap(); +/// let result 
= resize(&a, &[2, 3]).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 5, 5, 5, 6, 6, 6]), &[6, 6]).unwrap(); +/// assert_eq!(result, expected); +/// +/// +/// ``` +pub fn resize( + a: &Tensor, + scales: &[usize], +) -> Result, TensorError> { + let mut new_shape = vec![]; + for (s, d) in scales.iter().zip(a.dims()) { + new_shape.push(s * d); + } + + let mut output = Tensor::new(None, &new_shape)?; + + let cartesian_coord: Vec> = new_shape + .iter() + .map(|d| (0..*d)) + .multi_cartesian_product() + .collect(); + + // resize using nearest neighbour interpolation + // (i.e. just copy the value of the nearest neighbour to pad the tensor) + output = output.par_enum_map(|i, _| { + let mut coord = vec![]; + for (j, (c, _d)) in cartesian_coord[i].iter().zip(new_shape.iter()).enumerate() { + let scale = scales[j]; + let fragment = c / scale; + coord.push(fragment); + } + + Ok::<_, TensorError>(a.get(&coord)) + })?; + + Ok(output) +} + +/// Computes the einstein sum of a set of tensors. 
+/// # Arguments +/// * `equation` - Einstein summation equation +/// * `inputs` - Vector of tensors +/// # Examples +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::einsum; +/// +/// // matmul case +/// let x = Tensor::::new( +/// Some(&[2, 1, 2, 1, 1, 1]), +/// &[2, 3], +/// ).unwrap(); +/// let k = Tensor::::new( +/// Some(&[2, 3, 2, 1, 1, 1]), +/// &[3, 2], +/// ).unwrap(); +/// let result = einsum("ij,jk->ik", &[x, k]).unwrap(); +/// let expected = Tensor::::new(Some(&[8, 9, 5, 5]), &[2, 2]).unwrap(); +/// assert_eq!(result, expected); +/// +/// // element wise multiplication +/// let x = Tensor::::new( +/// Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5]), +/// &[3, 3], +/// ).unwrap(); +/// let k = Tensor::::new( +/// Some(&[1, 2, 3, 1, 2, 3, 1, 2, 3]), +/// &[3, 3], +/// ).unwrap(); +/// let result = einsum("ij,ij->ij", &[x, k]).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 4, 9, 2, 6, 12, 3, 8, 15]), &[3, 3]).unwrap(); +/// assert_eq!(result, expected); +/// +/// +/// // dot product of A with the transpose of B. 
+/// let x = Tensor::::new( +/// Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5]), +/// &[3, 3], +/// ).unwrap(); +/// let k = Tensor::::new( +/// Some(&[1, 2, 3, 1, 2, 3, 1, 2, 3]), +/// &[3, 3], +/// ).unwrap(); +/// let result = einsum("ik,jk->ij", &[x, k]).unwrap(); +/// let expected = Tensor::::new(Some(&[14, 14, 14, 20, 20, 20, 26, 26, 26]), &[3, 3]).unwrap(); +/// assert_eq!(result, expected); +/// +/// // dot product +/// let x = Tensor::::new( +/// Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5]), +/// &[3, 3], +/// ).unwrap(); +/// let k = Tensor::::new( +/// Some(&[1, 2, 3, 1, 2, 3, 1, 2, 3]), +/// &[3, 3], +/// ).unwrap(); +/// let result = einsum("ik,ik->i", &[x, k]).unwrap(); +/// let expected = Tensor::::new(Some(&[14, 20, 26]), &[3]).unwrap(); +/// assert_eq!(result, expected); +/// +/// +/// // dot product +/// let x = Tensor::::new( +/// Some(&[1, 2, 3]), +/// &[3], +/// ).unwrap(); +/// let k = Tensor::::new( +/// Some(&[1, 2, 3]), +/// &[3], +/// ).unwrap(); +/// let result = einsum("i,i->", &[x, k]).unwrap(); +/// let expected = Tensor::::new(Some(&[14]), &[1]).unwrap(); +/// assert_eq!(result, expected); +/// +/// +/// // wut ? +/// let x = Tensor::::new( +/// Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5, 1, 2, 3, 2, 3, 4, 3, 4, 5]), +/// &[3, 3, 2], +/// ).unwrap(); +/// let k = Tensor::::new( +/// Some(&[4, 5, 7, 8]), +/// &[2, 2], +/// ).unwrap(); +/// let result = einsum("anm,bm->ba", &[x, k]).unwrap(); +/// let expected = Tensor::::new(Some(&[68, 80, 95, 113, 134, 158]), &[2, 3]).unwrap(); +/// assert_eq!(result, expected); +/// +/// // wutttttt ? 
+/// let x = Tensor::::new( +/// Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5, 1, 2, 3, 2, 3, 4, 3, 4, 5]), +/// &[3, 3, 2], +/// ).unwrap(); +/// let k = Tensor::::new( +/// Some(&[4, 5, 7, 8]), +/// &[2, 2], +/// ).unwrap(); +/// let z = Tensor::::new( +/// Some(&[4, 5, 7, 8, 9, 9]), +/// &[2, 3], +/// ).unwrap(); +/// +/// let result = einsum("bn,anm,bm->ba", &[z, x, k]).unwrap(); +/// let expected = Tensor::::new(Some(&[390, 414, 534, 994, 1153, 1384]), &[2, 3]).unwrap(); +/// assert_eq!(result, expected); +/// +/// +/// // contraction with a single common axis +/// let x = Tensor::::new( +/// Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5, 1, 2, 3, 2, 3, 4, 3, 4, 5]), +/// &[3, 3, 2], +/// ).unwrap(); +/// let k = Tensor::::new( +/// Some(&[4, 5, 7, 8]), +/// &[2, 2], +/// ).unwrap(); +/// let result = einsum("abc,cd->", &[x, k]).unwrap(); +/// let expected = Tensor::::new(Some(&[648]), &[1]).unwrap(); +/// assert_eq!(result, expected); +/// +/// // contraction with no common axes (outer product) +/// let x = Tensor::::new( +/// Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5, 1, 2, 3, 2, 3, 4, 3, 4, 5]), +/// &[3, 3, 2], +/// ).unwrap(); +/// let k = Tensor::::new( +/// Some(&[4, 5, 7, 8]), +/// &[2, 2], +/// ).unwrap(); +/// let result = einsum("abc,ed->", &[x, k]).unwrap(); +/// let expected = Tensor::::new(Some(&[1296]), &[1]).unwrap(); +/// assert_eq!(result, expected); +/// +/// // trivial axes mapping +/// let x = Tensor::::new( +/// Some(&[4, 5, 7, 8]), +/// &[2, 2], +/// ).unwrap(); +/// let k = Tensor::::new( +/// Some(&[4, 5]), +/// &[2], +/// ).unwrap(); +/// +/// let result = einsum("mk,k->m", &[x.clone(), k.clone()]).unwrap(); +/// let expected = Tensor::::new(Some(&[41, 68]), &[2]).unwrap(); +/// assert_eq!(result, expected); +/// +/// let result = einsum("mk,k->mn", &[x, k]).unwrap(); +/// let expected = Tensor::::new(Some(&[41, 68]), &[2, 1]).unwrap(); +/// assert_eq!(result, expected); +/// +/// ``` +pub fn einsum< + T: TensorType + Mul + Add + std::marker::Send + std::marker::Sync, 
+>( + equation: &str, + inputs: &[Tensor], +) -> Result, TensorError> { + // Parse equation into an operation + + let mut equation = equation.split("->"); + let inputs_eq = equation.next().unwrap(); + let output_eq = equation.next().unwrap(); + let inputs_eq = inputs_eq.split(',').collect::>(); + + // Check that the number of inputs matches the number of inputs in the equation + if inputs.len() != inputs_eq.len() { + return Err(TensorError::DimMismatch("einsum".to_string())); + } + + let mut indices_to_size = HashMap::new(); + for (i, input) in inputs.iter().enumerate() { + for j in 0..inputs_eq[i].len() { + let c = inputs_eq[i].chars().nth(j).unwrap(); + if let std::collections::hash_map::Entry::Vacant(e) = indices_to_size.entry(c) { + e.insert(input.dims()[j]); + } else if indices_to_size[&c] != input.dims()[j] { + return Err(TensorError::DimMismatch("einsum".to_string())); + } + } + } + + // maps unrepresented indices in the output to a trivial 1 + for c in output_eq.chars() { + indices_to_size.entry(c).or_insert(1); + } + + // Compute the output tensor shape + let mut output_shape: Vec = output_eq + .chars() + .map(|c| *indices_to_size.get(&c).unwrap()) + .collect(); + + if output_shape.is_empty() { + output_shape.push(1); + } + + let mut seen = HashSet::new(); + let mut common_indices_to_inputs = vec![]; + for input in &inputs_eq { + for c in input.chars() { + if !seen.contains(&c) { + seen.insert(c); + } else { + common_indices_to_inputs.push(c); + } + } + } + + let cartesian_coord = output_shape + .iter() + .map(|d| 0..*d) + .multi_cartesian_product() + .collect::>(); + + // Compute the cartesian product of all indices + let output: Vec = cartesian_coord + .par_iter() + .map(|coord| { + // Compute the slice of each input tensor given the current coordinate of the output tensor + let inputs = (0..inputs.len()) + .map(|idx| { + let mut slice = vec![]; + for (i, c) in inputs_eq[idx].chars().enumerate() { + // If the current index is in the output equation, then 
the slice should be the current coordinate + if let Some(idx) = output_eq.find(c) { + slice.push(coord[idx]..coord[idx] + 1); + // Otherwise, the slice should be the entire dimension of the input tensor + } else { + slice.push(0..inputs[idx].dims()[i]); + } + } + + // Get the slice of the input tensor + inputs[idx].get_slice(&slice).unwrap() + }) + .collect::>(); + + // Get the indices common accross input tensors + let mut common_coord = common_indices_to_inputs + .iter() + .map(|d| { + // If the current index is in the output equation, then the slice should be the current coordinate + if output_eq.contains(*d) { + 0..1 + // Otherwise, the slice should be the entire dimension of the input tensor + } else { + 0..*indices_to_size.get(d).unwrap() + } + }) + .multi_cartesian_product() + .collect::>(); + + // If there are no common indices, then we need to add an empty slice to force one iteration of the loop + if common_coord.is_empty() { + common_coord.push(vec![]); + } + + let mut prod = T::zero().unwrap(); + + // Compute the cartesian product of all common indices + for common_dim in common_coord { + let inputs = (0..inputs.len()) + .map(|idx| { + let mut slice = vec![]; + // Iterate over all indices in the input equation + for (i, c) in inputs_eq[idx].chars().enumerate() { + // If the current index is common to multiple inputs, then the slice should be the current coordinate + if let Some(j) = common_indices_to_inputs.iter().position(|&r| r == c) { + slice.push(common_dim[j]..common_dim[j] + 1); + } else { + slice.push(0..inputs[idx].dims()[i]); + } + } + // Get the slice of the input tensor + inputs[idx].get_slice(&slice).unwrap() + }) + .collect::>(); + + let input_pairs = inputs + .iter() + .map(|d| d.iter()) + .multi_cartesian_product() + .collect::>(); + + // Compute the product of all input tensors + for pair in input_pairs { + prod = prod + + pair + .into_iter() + .fold(T::one().unwrap(), |acc, x| acc * x.clone()); + } + } + prod + }) + .collect(); + + let 
mut output: Tensor = output.into_iter().into(); + output.reshape(&output_shape)?; + + Ok(output) +} + +/// Adds multiple tensors. +/// # Arguments +/// +/// * `t` - Vector of tensors +/// # Examples +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::add; +/// let x = Tensor::::new( +/// Some(&[2, 1, 2, 1, 1, 1]), +/// &[2, 3], +/// ).unwrap(); +/// let k = Tensor::::new( +/// Some(&[2, 3, 2, 1, 1, 1]), +/// &[2, 3], +/// ).unwrap(); +/// let result = add(&[x, k]).unwrap(); +/// let expected = Tensor::::new(Some(&[4, 4, 4, 2, 2, 2]), &[2, 3]).unwrap(); +/// assert_eq!(result, expected); +/// +/// // Now test 1D casting +/// let x = Tensor::::new( +/// Some(&[2, 1, 2, 1, 1, 1]), +/// &[2, 3], +/// ).unwrap(); +/// let k = Tensor::::new( +/// Some(&[2]), +/// &[1]).unwrap(); +/// let result = add(&[x, k]).unwrap(); +/// let expected = Tensor::::new(Some(&[4, 3, 4, 3, 3, 3]), &[2, 3]).unwrap(); +/// assert_eq!(result, expected); +/// ``` +pub fn add + std::marker::Send + std::marker::Sync>( + t: &[Tensor], +) -> Result, TensorError> { + // calculate value of output + let mut output: Tensor = t[0].clone(); + + for e in t[1..].iter() { + output = output.add(e.clone())?; + } + + Ok(output) +} + +/// Subtracts multiple tensors. 
+/// # Arguments +/// +/// * `a` - Tensor +/// * `b` - Tensor +/// # Examples +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::sub; +/// let x = Tensor::::new( +/// Some(&[2, 1, 2, 1, 1, 1]), +/// &[2, 3], +/// ).unwrap(); +/// let k = Tensor::::new( +/// Some(&[2, 3, 2, 1, 1, 1]), +/// &[2, 3], +/// ).unwrap(); +/// let result = sub(&[x, k]).unwrap(); +/// let expected = Tensor::::new(Some(&[0, -2, 0, 0, 0, 0]), &[2, 3]).unwrap(); +/// assert_eq!(result, expected); +/// +/// // Now test 1D sub +/// let x = Tensor::::new( +/// Some(&[2, 1, 2, 1, 1, 1]), +/// &[2, 3], +/// ).unwrap(); +/// let k = Tensor::::new( +/// Some(&[2]), +/// &[1], +/// ).unwrap(); +/// let result = sub(&[x, k]).unwrap(); +/// let expected = Tensor::::new(Some(&[0, -1, 0, -1, -1, -1]), &[2, 3]).unwrap(); +/// assert_eq!(result, expected); +/// ``` +pub fn sub + std::marker::Send + std::marker::Sync>( + t: &[Tensor], +) -> Result, TensorError> { + // calculate value of output + let mut output: Tensor = t[0].clone(); + + for e in t[1..].iter() { + output = (output - e.clone())?; + } + + Ok(output) +} + +/// Negates a tensor. +/// # Arguments +/// +/// * `a` - Tensor +/// # Examples +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::neg; +/// let x = Tensor::::new( +/// Some(&[2, 1, 2, 1, 1, 1]), +/// &[2, 3], +/// ).unwrap(); +/// let result = neg(&x).unwrap(); +/// let expected = Tensor::::new(Some(&[-2, -1, -2, -1, -1, -1]), &[2, 3]).unwrap(); +/// assert_eq!(result, expected); +/// ``` +pub fn neg + std::marker::Send + std::marker::Sync>( + t: &Tensor, +) -> Result, TensorError> { + // calculate value of output + Ok(-t.clone()) +} + +/// Elementwise multiplies multiple tensors. 
+/// # Arguments +/// +/// * `a` - Tensor +/// * `b` - Tensor +/// # Examples +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::mult; +/// let x = Tensor::::new( +/// Some(&[2, 1, 2, 1, 1, 1]), +/// &[2, 3], +/// ).unwrap(); +/// let k = Tensor::::new( +/// Some(&[2, 3, 2, 1, 1, 1]), +/// &[2, 3], +/// ).unwrap(); +/// let result = mult(&[x, k]).unwrap(); +/// let expected = Tensor::::new(Some(&[4, 3, 4, 1, 1, 1]), &[2, 3]).unwrap(); +/// assert_eq!(result, expected); +/// +/// // Now test 1D mult +/// let x = Tensor::::new( +/// Some(&[2, 1, 2, 1, 1, 1]), +/// &[2, 3], +/// ).unwrap(); +/// let k = Tensor::::new( +/// Some(&[2]), +/// &[1]).unwrap(); +/// let result = mult(&[x, k]).unwrap(); +/// let expected = Tensor::::new(Some(&[4, 2, 4, 2, 2, 2]), &[2, 3]).unwrap(); +/// assert_eq!(result, expected); +/// ``` +pub fn mult + std::marker::Send + std::marker::Sync>( + t: &[Tensor], +) -> Result, TensorError> { + // calculate value of output + let mut output: Tensor = t[0].clone(); + + for e in t[1..].iter() { + output = (output * e.clone())?; + } + + Ok(output) +} + +/// Rescale a tensor with a const integer (similar to const_mult). +/// # Arguments +/// +/// * `a` - Tensor +/// * `b` - Single value +/// # Examples +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::rescale; +/// let x = Tensor::::new( +/// Some(&[2, 1, 2, 1, 1, 1]), +/// &[2, 3], +/// ).unwrap(); +/// let k = 2; +/// let result = rescale(&x, k).unwrap(); +/// let expected = Tensor::::new(Some(&[4, 2, 4, 2, 2, 2]), &[2, 3]).unwrap(); +/// assert_eq!(result, expected); +/// ``` +pub fn rescale + std::marker::Send + std::marker::Sync>( + a: &Tensor, + mult: u128, +) -> Result, TensorError> { + // calculate value of output + let mut output: Tensor = a.clone(); + output.par_iter_mut().enumerate().for_each(|(i, a_i)| { + for _ in 1..mult { + *a_i = a_i.clone() + a[i].clone(); + } + }); + Ok(output) +} + +/// Sums a tensor. 
+/// # Arguments +/// +/// * `a` - Tensor +/// * `b` - Single value +/// # Examples +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::sum; +/// let x = Tensor::::new( +/// Some(&[2, 15, 2, 1, 1, 0]), +/// &[2, 3], +/// ).unwrap(); +/// let result = sum(&x).unwrap(); +/// let expected = 21; +/// assert_eq!(result[0], expected); +/// ``` +pub fn sum>(a: &Tensor) -> Result, TensorError> { + // calculate value of output + let mut res = T::zero().unwrap(); + + let _ = a.map(|a_i| res = res.clone() + a_i); + Tensor::new(Some(&[res]), &[1]) +} + +/// Takes prod of tensor's elements. +/// # Arguments +/// +/// * `a` - Tensor +/// * `b` - Single value +/// # Examples +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::prod; +/// let x = Tensor::::new( +/// Some(&[2, 15, 2, 1, 1, 0]), +/// &[2, 3], +/// ).unwrap(); +/// let result = prod(&x).unwrap(); +/// let expected = 0; +/// assert_eq!(result[0], expected); +/// ``` +pub fn prod>(a: &Tensor) -> Result, TensorError> { + // calculate value of output + let mut res = T::one().unwrap(); + + let _ = a.map(|a_i| res = res.clone() * a_i); + Tensor::new(Some(&[res]), &[1]) +} + +/// Downsamples a tensor along a dimension. 
+/// # Arguments +/// * `input` - Tensor +/// * `dim` - Dimension to downsample along +/// * `stride` - Stride to downsample by +/// * `modulo` - Modulo to downsample by +/// # Examples +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::downsample; +/// let x = Tensor::::new( +/// Some(&[1, 2, 3, 4, 5, 6]), +/// &[2, 3], +/// ).unwrap(); +/// let result = downsample(&x, 0, 1, 1).unwrap(); +/// let expected = Tensor::::new(Some(&[4, 5, 6]), &[1, 3]).unwrap(); +/// assert_eq!(result, expected); +/// +/// let result = downsample(&x, 1, 2, 0).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 3, 4, 6]), &[2, 2]).unwrap(); +/// assert_eq!(result, expected); +/// +/// let result = downsample(&x, 1, 2, 1).unwrap(); +/// let expected = Tensor::::new(Some(&[2, 5]), &[2, 1]).unwrap(); +/// assert_eq!(result, expected); +/// +/// let result = downsample(&x, 1, 2, 2).unwrap(); +/// let expected = Tensor::::new(Some(&[3, 6]), &[2, 1]).unwrap(); +/// assert_eq!(result, expected); +pub fn downsample( + input: &Tensor, + dim: usize, + stride: usize, + modulo: usize, +) -> Result, TensorError> { + let mut output_shape = input.dims().to_vec(); + // now downsample along axis dim offset by modulo, rounding up (+1 if remaidner is non-zero) + let remainder = (input.dims()[dim] - modulo) % stride; + let div = (input.dims()[dim] - modulo) / stride; + output_shape[dim] = div + (remainder > 0) as usize; + let mut output = Tensor::::new(None, &output_shape)?; + + if modulo > input.dims()[dim] { + return Err(TensorError::DimMismatch("downsample".to_string())); + } + + // now downsample along axis dim offset by modulo + let indices = (0..output_shape.len()) + .map(|i| { + if i == dim { + let mut index = vec![0; output_shape[i]]; + for (i, idx) in index.iter_mut().enumerate() { + *idx = i * stride + modulo; + } + index + } else { + (0..output_shape[i]).collect_vec() + } + }) + .multi_cartesian_product() + .collect::>(); + + output = output.par_enum_map(|i, _: T| { + let 
coord = indices[i].clone(); + Ok(input.get(&coord)) + })?; + + Ok(output) +} + +/// Gathers a tensor along a dimension. +/// # Arguments +/// * `input` - Tensor +/// * `dim` - Dimension to gather along +/// * `index` - Tensor of indices to gather +/// # Examples +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::gather; +/// let x = Tensor::::new( +/// Some(&[1, 2, 3, 4, 5, 6]), +/// &[2, 3], +/// ).unwrap(); +/// let index = Tensor::::new( +/// Some(&[0, 1]), +/// &[2], +/// ).unwrap(); +/// let result = gather(&x, &index, 1).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 2, 4, 5]), &[2, 2]).unwrap(); +/// assert_eq!(result, expected); +/// ``` +pub fn gather( + input: &Tensor, + index: &Tensor, + dim: usize, +) -> Result, TensorError> { + let mut index_clone = index.clone(); + index_clone.flatten(); + if index_clone.is_singleton() { + index_clone.reshape(&[1])?; + } + + // Calculate the output tensor size + let mut output_size = input.dims().to_vec(); + output_size[dim] = index_clone.dims()[0]; + + // Allocate memory for the output tensor + let mut output = Tensor::new(None, &output_size)?; + let cartesian_coord = output_size + .iter() + .map(|x| 0..*x) + .multi_cartesian_product() + .collect::>(); + + output = output.par_enum_map(|i, _: T| { + let coord = cartesian_coord[i].clone(); + let index_val = index_clone.get(&[coord[dim]]); + let new_coord = coord + .iter() + .enumerate() + .map(|(i, x)| if i == dim { index_val } else { *x }) + .collect::>(); + + Ok(input.get(&new_coord)) + })?; + + // Reshape the output tensor + if index.is_singleton() { + output_size.remove(dim); + } + + output.reshape(&output_size)?; + + Ok(output) +} + +/// Scatters a tensor along a dimension. 
+/// # Arguments +/// * `input` - Tensor +/// * `dim` - Dimension to scatter along +/// * `index` - Tensor of indices to scatter +/// # Examples +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::scatter; +/// let x = Tensor::::new( +/// Some(&[1.0, 2.0, 3.0, 4.0]), +/// &[2, 2], +/// ).unwrap(); +/// let src = Tensor::::new( +/// Some(&[5.0, 6.0, 7.0, 8.0]), +/// &[2, 2], +/// ).unwrap(); +/// let index = Tensor::::new( +/// Some(&[0, 0, 1, 0]), +/// &[2, 2], +/// ).unwrap(); +/// let result = scatter(&x, &index, &src, 0).unwrap(); +/// let expected = Tensor::::new(Some(&[5.0, 8.0, 7.0, 4.0]), &[2, 2]).unwrap(); +/// assert_eq!(result, expected); +/// ``` +pub fn scatter( + input: &Tensor, + index: &Tensor, + src: &Tensor, + dim: usize, +) -> Result, TensorError> { + // Writes all values from the tensor src into self at the indices specified in the index tensor. + // For each value in src, its output index is specified by its index in src for dimension != dim and by the corresponding value in index for dimension = dim. + assert_eq!(index.dims(), src.dims()); + // Calculate the output tensor size + let src_size = src.dims().to_vec(); + + // For a 3-D tensor, self is updated as: + // self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0 + // self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1 + // self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2 + + let mut output = input.clone(); + + let cartesian_coord = src_size + .iter() + .map(|x| 0..*x) + .multi_cartesian_product() + .collect::>(); + + cartesian_coord.iter().for_each(|coord| { + let mut new_coord = coord.clone(); + let index_val = index.get(coord); + new_coord[dim] = index_val; + let val = src.get(coord); + output.set(&new_coord, val); + }); + + Ok(output) +} + +/// Gathers a tensor along a dimension. 
+/// # Arguments +/// * `input` - Tensor +/// * `dim` - Dimension to gather along +/// * `index` - Tensor of indices to gather +/// # Examples +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::gather_elements; +/// let x = Tensor::::new( +/// Some(&[1, 2, 3, 4]), +/// &[2, 2], +/// ).unwrap(); +/// let index = Tensor::::new( +/// Some(&[0, 0, 1, 0]), +/// &[2, 2], +/// ).unwrap(); +/// let result = gather_elements(&x, &index, 1).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 1, 4, 3]), &[2, 2]).unwrap(); +/// assert_eq!(result, expected); +/// ``` +pub fn gather_elements( + input: &Tensor, + index: &Tensor, + dim: usize, +) -> Result, TensorError> { + // Calculate the output tensor size + let output_size = index.dims().to_vec(); + // same rank + assert_eq!(input.dims().len(), index.dims().len()); + + // Allocate memory for the output tensor + let mut output = Tensor::new(None, &output_size)?; + let cartesian_coord = output_size + .iter() + .map(|x| 0..*x) + .multi_cartesian_product() + .collect::>(); + + output = output.par_enum_map(|i, _: T| { + let coord = cartesian_coord[i].clone(); + let index_val = index.get(&coord); + + let mut new_coord = coord.clone(); + new_coord[dim] = index_val; + + let val = input.get(&new_coord); + + Ok(val) + })?; + + // Reshape the output tensor + output.reshape(&output_size)?; + + Ok(output) +} + +fn axes_op( + a: &Tensor, + axes: &[usize], + op: impl Fn(&Tensor) -> Result, TensorError> + Send + Sync, +) -> Result, TensorError> { + // calculate value of output + + if axes.is_empty() { + return Ok(a.clone()); + } + + let mut new_dims = vec![]; + for i in 0..a.dims().len() { + if !axes.contains(&i) { + new_dims.push(a.dims()[i]); + } else { + new_dims.push(1); + } + } + + let res = Tensor::new(None, &new_dims)?; + + let cartesian_coord = new_dims + .iter() + .map(|x| 0..*x) + .multi_cartesian_product() + .collect::>(); + + let res = res.par_enum_map(|i, _: T| { + let coord = cartesian_coord[i].clone(); + let 
mut prod_dims = vec![]; + for (i, c) in coord.iter().enumerate() { + if axes.contains(&i) { + prod_dims.push(0..a.dims()[i]); + } else { + prod_dims.push(*c..*c + 1); + } + } + + Ok(op(&a.get_slice(&prod_dims)?)?[0].clone()) + })?; + + Ok(res) +} + +/// Takes product of a tensor along specific axes. +/// # Arguments +/// +/// * `a` - Tensor +/// * `b` - Single value +/// # Examples +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::prod_axes; +/// let x = Tensor::::new( +/// Some(&[2, 15, 2, 1, 1, 0]), +/// &[2, 3], +/// ).unwrap(); +/// let result = prod_axes(&x, &[1]).unwrap(); +/// let expected = Tensor::::new( +/// Some(&[60, 0]), +/// &[2, 1], +/// ).unwrap(); +/// assert_eq!(result, expected); +/// ``` +pub fn prod_axes + Send + Sync>( + a: &Tensor, + axes: &[usize], +) -> Result, TensorError> { + // calculate value of output + axes_op(a, axes, prod) +} + +/// Returns top K values. +/// # Arguments +/// +/// * `a` - Tensor +/// * `b` - Single value +/// # Examples +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::topk; +/// let x = Tensor::::new( +/// Some(&[2, 15, 2, 1, 1, 0]), +/// &[6], +/// ).unwrap(); +/// let result = topk(&x, 3).unwrap(); +/// let expected = Tensor::::new( +/// Some(&[15, 2, 2]), +/// &[3], +/// ).unwrap(); +/// assert_eq!(result, expected); +/// ``` +pub fn topk(a: &Tensor, k: usize) -> Result, TensorError> { + let mut indexed_a = a.clone(); + indexed_a.flatten(); + + let mut indexed_a = a + .iter() + .enumerate() + .map(|(i, x)| (i, x)) + .collect::>(); + + indexed_a.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap()); + + let indexed_a = indexed_a + .into_iter() + .take(k) + .map(|(i, _)| i) + .collect::>(); + + let mut output = Tensor::new(None, &[k])?; + + for (i, x) in indexed_a.iter().enumerate() { + output.set(&[i], a[*x].clone()); + } + + Ok(output) +} + +/// Returns top K values. 
+/// # Arguments +/// +/// * `a` - Tensor +/// * `b` - Single value +/// # Examples +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::topk_axes; +/// let x = Tensor::::new( +/// Some(&[2, 15, 2, 1, 1, 0]), +/// &[2,3], +/// ).unwrap(); +/// let result = topk_axes(&x, 2, 1).unwrap(); +/// let expected = Tensor::::new( +/// Some(&[15, 2, 1, 1]), +/// &[2,2], +/// ).unwrap(); +/// assert_eq!(result, expected); +/// ``` +pub fn topk_axes( + a: &Tensor, + k: usize, + dim: usize, +) -> Result, TensorError> { + let mut new_dims = a.dims().to_vec(); + new_dims[dim] = k; + + let res = Tensor::new(None, &new_dims)?; + + let cartesian_coord = new_dims + .iter() + .map(|x| 0..*x) + .multi_cartesian_product() + .collect::>(); + + let res = res.par_enum_map(|i, _: T| { + let coord = cartesian_coord[i].clone(); + let mut slice = vec![]; + for (i, c) in coord.iter().enumerate() { + if i == dim { + slice.push(0..a.dims()[i]); + } else { + slice.push(*c..*c + 1); + } + } + let sliced_value = a.get_slice(&slice)?; + let topk = topk(&sliced_value, k)?; + Ok(topk[coord[dim]].clone()) + })?; + + Ok(res) +} + +/// Sums a tensor along specific axes. +/// # Arguments +/// +/// * `a` - Tensor +/// * `b` - Single value +/// # Examples +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::sum_axes; +/// let x = Tensor::::new( +/// Some(&[2, 15, 2, 1, 1, 0]), +/// &[2, 3], +/// ).unwrap(); +/// let result = sum_axes(&x, &[1]).unwrap(); +/// let expected = Tensor::::new( +/// Some(&[19, 2]), +/// &[2, 1], +/// ).unwrap(); +/// assert_eq!(result, expected); +/// ``` +pub fn sum_axes + Send + Sync>( + a: &Tensor, + axes: &[usize], +) -> Result, TensorError> { + // calculate value of output + axes_op(a, axes, sum) +} + +/// Mins a tensor along specific axes. 
+/// # Arguments +/// +/// * `a` - Tensor +/// * `b` - Single value +/// # Examples +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::min_axes; +/// let x = Tensor::::new( +/// Some(&[2, 15, 2, 1, 1, 0]), +/// &[2, 3], +/// ).unwrap(); +/// let result = min_axes(&x, &[1]).unwrap(); +/// let expected = Tensor::::new( +/// Some(&[2, 0]), +/// &[2, 1], +/// ).unwrap(); +/// assert_eq!(result, expected); +/// ``` +pub fn min_axes + std::cmp::Ord + Send + Sync>( + a: &Tensor, + axes: &[usize], +) -> Result, TensorError> { + // calculate value of output + + let min_fn = |a: &Tensor| -> Result, TensorError> { + Ok(vec![a.par_iter().min().unwrap().clone()].into_iter().into()) + }; + + axes_op(a, axes, min_fn) +} + +/// Abs a tensor. +/// # Arguments +/// * `a` - Tensor +/// # Examples +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::abs; +/// let x = Tensor::::new( +/// Some(&[-2, 15, 2, -1, 1, 0]), +/// &[2, 3], +/// ).unwrap(); +/// let result = abs(&x).unwrap(); +/// let expected = Tensor::::new(Some(&[2, 15, 2, 1, 1, 0]), &[2, 3]).unwrap(); +/// assert_eq!(result, expected); +/// ``` +pub fn abs + std::cmp::Ord + Neg>( + a: &Tensor, +) -> Result, TensorError> { + // calculate value of output + let mut output: Tensor = a.clone(); + output.iter_mut().for_each(|a_i| { + if *a_i < T::zero().unwrap() { + *a_i = -a_i.clone(); + } + }); + Ok(output) +} + +/// Max of a tensor along specific axes. 
+/// # Arguments +/// +/// * `a` - Tensor +/// * `b` - Single value +/// # Examples +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::max_axes; +/// let x = Tensor::::new( +/// Some(&[2, 15, 2, 1, 1, 0]), +/// &[2, 3], +/// ).unwrap(); +/// let result = max_axes(&x, &[1]).unwrap(); +/// let expected = Tensor::::new( +/// Some(&[15, 1]), +/// &[2, 1], +/// ).unwrap(); +/// assert_eq!(result, expected); +/// ``` +pub fn max_axes + std::cmp::Ord + Send + Sync>( + a: &Tensor, + axes: &[usize], +) -> Result, TensorError> { + // calculate value of output + + let max_fn = |a: &Tensor| -> Result, TensorError> { + Ok(vec![a.par_iter().max().unwrap().clone()].into_iter().into()) + }; + + axes_op(a, axes, max_fn) +} + +/// Argmax of a tensor along specific axes. +/// # Arguments +/// +/// * `a` - Tensor +/// * `b` - Single value +/// # Examples +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::argmax_axes; +/// let x = Tensor::::new( +/// Some(&[2, 15, 2, 1, 1, 0]), +/// &[2, 3], +/// ).unwrap(); +/// let result = argmax_axes(&x, 1).unwrap(); +/// let expected = Tensor::::new( +/// Some(&[1, 0]), +/// &[2, 1], +/// ).unwrap(); +/// assert_eq!(result, expected); +/// ``` +pub fn argmax_axes + std::cmp::Ord + From + Send + Sync>( + a: &Tensor, + dim: usize, +) -> Result, TensorError> { + let argmax_fn = |a: &Tensor| -> Result, TensorError> { + Ok(vec![a + .clone() + .into_iter() + .enumerate() + // we value the first index in the case of a tie + .max_by_key(|(idx, value)| (value.clone(), -(*idx as i64))) + .map(|(idx, _)| T::from(idx as u64)) + .unwrap()] + .into_iter() + .into()) + }; + + // calculate value of output + axes_op(a, &[dim], argmax_fn) +} + +/// Argmin of a tensor along specific axes. 
+/// # Arguments +/// +/// * `a` - Tensor +/// * `b` - Single value +/// # Examples +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::argmin_axes; +/// let x = Tensor::::new( +/// Some(&[2, 15, 0, 1, 1, 0]), +/// &[2, 3], +/// ).unwrap(); +/// let result = argmin_axes(&x, 0).unwrap(); +/// let expected = Tensor::::new( +/// Some(&[1, 1, 0]), +/// &[1, 3], +/// ).unwrap(); +/// assert_eq!(result, expected); +/// ``` +pub fn argmin_axes + std::cmp::Ord + From + Send + Sync>( + a: &Tensor, + dim: usize, +) -> Result, TensorError> { + let argmax_fn = |a: &Tensor| -> Result, TensorError> { + Ok(vec![a + .clone() + .into_iter() + .enumerate() + // we value the first index in the case of a tie + .min_by_key(|(idx, value)| (value.clone(), (*idx as i64))) + .map(|(idx, _)| T::from(idx as u64)) + .unwrap()] + .into_iter() + .into()) + }; + + // calculate value of output + axes_op(a, &[dim], argmax_fn) +} + +/// Applies convolution over a 3D tensor of shape C x H x W (and adds a bias). +/// # Arguments +/// +/// * `inputs` - A vector of tensors holding in order: input image, convolution kernel, convolution bias. +/// * `padding` - Tuple of padding values in x and y directions. +/// * `stride` - Tuple of stride values in x and y directions. 
+/// # Examples +/// ``` +/// // expected ouputs are taken from pytorch torch.nn.functional.conv2d +/// +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::conv; +/// +/// let x = Tensor::::new( +/// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]), +/// &[1, 1, 3, 3], +/// ).unwrap(); +/// let k = Tensor::::new( +/// Some(&[5, 1, 1, 1]), +/// &[1, 1, 2, 2], +/// ).unwrap(); +/// let b = Tensor::::new( +/// Some(&[0]), +/// &[1], +/// ).unwrap(); +/// let result = conv::(&[x, k, b], [(0, 0); 2], (1, 1)).unwrap(); +/// let expected = Tensor::::new(Some(&[31, 16, 8, 26]), &[1, 1, 2, 2]).unwrap(); +/// assert_eq!(result, expected); +/// +/// // Now test single channel +/// let x = Tensor::::new( +/// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6, 5, 2, 3, 0, 4, -1, 3, 1, 6]), +/// &[1, 2, 3, 3], +/// ).unwrap(); +/// let k = Tensor::::new( +/// Some(&[5, 1, 1, 1, 5, 2, 1, 1]), +/// &[2, 1, 2, 2], +/// ).unwrap(); +/// let b = Tensor::::new( +/// Some(&[1, 1]), +/// &[2], +/// ).unwrap(); +/// +/// let result = conv::(&[x, k, b], [(0, 0); 2], (1, 1)).unwrap(); +/// let expected = Tensor::::new(Some(&[32, 17, 9, 27, 34, 20, 13, 26]), &[1, 2, 2, 2]).unwrap(); +/// assert_eq!(result, expected); +/// +/// // Now test multi channel +/// let x = Tensor::::new( +/// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6, 5, 2, 3, 0, 4, -1, 3, 1, 6]), +/// &[1, 2, 3, 3], +/// ).unwrap(); +/// let k = Tensor::::new( +/// Some(&[5, 1, 1, 1, 5, 2, 1, 1, 5, 3, 1, 1, 5, 4, 1, 1, 5, 1, 1, 1, 5, 2, 1, 1, 5, 3, 1, 1, 5, 4, 1, 1]), +/// &[4, 2, 2, 2], +/// ).unwrap(); +/// let b = Tensor::::new( +/// Some(&[1, 1, 1, 1]), +/// &[4], +/// ).unwrap(); +/// +/// let result = conv::(&[x, k, b], [(0, 0); 2], (1, 1)).unwrap(); +/// let expected = Tensor::::new(Some(&[65, 36, 21, 52, 73, 48, 37, 48, 65, 36, 21, 52, 73, 48, 37, 48]), &[1, 4, 2, 2]).unwrap(); +/// assert_eq!(result, expected); +/// ``` +pub fn conv< + T: TensorType + + Mul + + Add + + std::marker::Sync + + std::marker::Send + + std::iter::Sum, +>( + inputs: 
&[Tensor], + padding: [(usize, usize); 2], + stride: (usize, usize), +) -> Result, TensorError> { + let has_bias = inputs.len() == 3; + let (image, kernel) = (&mut inputs[0].clone(), &mut inputs[1].clone()); + + let og_image_dims = image.dims().to_vec(); + let og_kernel_dims = kernel.dims().to_vec(); + // ensure inputs are 4D tensors + if og_image_dims.len() == 3 { + // adds a dummy image_channels dimension + let mut new_dims = image.dims().to_vec(); + // insert 1 at the input_channels pos + if og_kernel_dims.len() == 3 { + new_dims.insert(1, 1); + } else { + new_dims.insert(0, 1); + } + image.reshape(&new_dims)?; + } + + // ensure kernel is 4D tensor + if og_kernel_dims.len() == 3 && og_image_dims.len() == 3 { + // adds a dummy image_channels dimension + let mut new_dims = kernel.dims().to_vec(); + // insert 1 at the input_channels pos + new_dims.insert(1, 1); + kernel.reshape(&new_dims)?; + } + + if (image.dims().len() != 4) + || (kernel.dims().len() != 4) + // ensure number of groups makes sense + || (image.dims()[1] % kernel.dims()[1] != 0) + { + return Err(TensorError::DimMismatch("conv".to_string())); + } + + let image_dims = image.dims(); + let kernel_dims = kernel.dims(); + + if has_bias { + let bias = &inputs[2]; + if (bias.dims().len() != 1) || (bias.dims()[0] != kernel.dims()[0]) { + return Err(TensorError::DimMismatch("conv bias".to_string())); + } + } + + let (batch_size, output_channels, input_channels, kernel_height, kernel_width) = ( + image_dims[0], + kernel_dims[0], + image_dims[1], + kernel_dims[2], + kernel_dims[3], + ); + + let (image_height, image_width) = (image_dims[2], image_dims[3]); + + let padded_image = pad::(image, padding)?; + + let vert_slides = (image_height + padding[0].0 + padding[1].0 - kernel_height) / stride.0 + 1; + let horz_slides = (image_width + padding[0].1 + padding[1].1 - kernel_width) / stride.1 + 1; + + let num_groups = input_channels / kernel_dims[1]; + let input_channels_per_group = input_channels / num_groups; + let 
output_channels_per_group = output_channels / num_groups; + + if output_channels_per_group == 0 { + return Err(TensorError::DimMismatch(format!( + "Given groups={}, expected kernel to be at least {} at dimension 0 but got {} instead", + num_groups, num_groups, output_channels_per_group + ))); + } + + let num_outputs = + batch_size * num_groups * output_channels_per_group * vert_slides * horz_slides; + + let mut output = Tensor::new(None, &[num_outputs])?; + + let cartesian_coord = [ + (0..batch_size), + (0..num_groups), + (0..output_channels_per_group), + (0..vert_slides), + (0..horz_slides), + ] + .iter() + .cloned() + .multi_cartesian_product() + .collect::>(); + + output.par_iter_mut().enumerate().for_each(|(i, o)| { + let cartesian_coord_per_group = &cartesian_coord[i]; + let (batch, group, i, j, k) = ( + cartesian_coord_per_group[0], + cartesian_coord_per_group[1], + cartesian_coord_per_group[2], + cartesian_coord_per_group[3], + cartesian_coord_per_group[4], + ); + let rs = j * stride.0; + let cs = k * stride.1; + + let start_channel = group * input_channels_per_group; + let end_channel = start_channel + input_channels_per_group; + + let local_image = padded_image + .get_slice(&[ + batch..batch + 1, + start_channel..end_channel, + rs..(rs + kernel_height), + cs..(cs + kernel_width), + ]) + .unwrap(); + + let start_kernel_index = group * output_channels_per_group + i; + let end_kernel_index = start_kernel_index + 1; + let local_kernel = kernel + .get_slice(&[start_kernel_index..end_kernel_index]) + .unwrap(); + + let res = dot(&[local_image, local_kernel]).unwrap()[0].clone(); + if has_bias { + *o = res + inputs[2][start_kernel_index].clone(); + } else { + *o = res; + } + }); + + // remove dummy batch dimension if we added one + if og_image_dims.len() == 3 && vert_slides == 1 { + output.reshape(&[batch_size, output_channels, horz_slides])?; + } else if og_image_dims.len() == 3 { + output.reshape(&[output_channels, vert_slides, horz_slides])?; + } else { + 
output.reshape(&[batch_size, output_channels, vert_slides, horz_slides])?; + } + + Ok(output) +} + +/// Intercalates values into a tensor along a given axis. +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::intercalate_values; +/// +/// let tensor = Tensor::::new(Some(&[1, 2, 3, 4]), &[2, 2]).unwrap(); +/// let result = intercalate_values(&tensor, 0, 2, 1).unwrap(); +/// +/// let expected = Tensor::::new(Some(&[1, 0, 2, 3, 0, 4]), &[2, 3]).unwrap(); +/// assert_eq!(result, expected); +/// +/// let result = intercalate_values(&expected, 0, 2, 0).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 0, 2, 0, 0, 0, 3, 0, 4]), &[3, 3]).unwrap(); +/// +/// assert_eq!(result, expected); +/// +/// ``` +pub fn intercalate_values( + tensor: &Tensor, + value: T, + stride: usize, + axis: usize, +) -> Result, TensorError> { + if stride == 1 { + return Ok(tensor.clone()); + } + + let mut output_dims = tensor.dims().to_vec(); + output_dims[axis] = output_dims[axis] * stride - 1; + + let mut output: Tensor = Tensor::new(None, &output_dims)?; + + let cartesian_coord = output + .dims() + .iter() + .map(|d| (0..*d)) + .multi_cartesian_product() + .collect::>(); + + let mut tensor_slice_iter = tensor.iter(); + + output.iter_mut().enumerate().for_each(|(i, o)| { + let coord = &cartesian_coord[i]; + + if coord[axis] % stride == 0 { + *o = tensor_slice_iter.next().unwrap().clone(); + } else { + *o = value.clone(); + } + }); + + Ok(output) +} + +/// One hot encodes a tensor along a given axis. 
+/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::one_hot; +/// let tensor = Tensor::::new(Some(&[1, 2, 3, 4]), &[2, 2]).unwrap(); +/// let result = one_hot(&tensor, 5, 2).unwrap(); +/// let expected = Tensor::::new(Some(&[0, 1, 0, 0, 0, +/// 0, 0, 1, 0, 0, +/// 0, 0, 0, 1, 0, +/// 0, 0, 0, 0, 1]), &[2, 2, 5]).unwrap(); +/// assert_eq!(result, expected); +/// ``` +pub fn one_hot( + tensor: &Tensor, + num_classes: usize, + axis: usize, +) -> Result, TensorError> { + let mut output_dims = tensor.dims().to_vec(); + output_dims.insert(axis, num_classes); + + let mut output: Tensor = Tensor::new(None, &output_dims)?; + + let cartesian_coord = output + .dims() + .iter() + .map(|d| (0..*d)) + .multi_cartesian_product() + .collect::>(); + + output + .iter_mut() + .enumerate() + .map(|(i, o)| { + let coord = &cartesian_coord[i]; + let coord_axis = coord[axis]; + + let mut coord_without_axis = coord.clone(); + coord_without_axis.remove(axis); + + let elem = tensor.get(&coord_without_axis) as usize; + if elem > num_classes { + return Err(TensorError::DimMismatch(format!( + "Expected element to be less than num_classes, but got {}", + elem + ))); + }; + + if coord_axis == elem { + *o = 1; + } else { + *o = 0; + } + Ok(()) + }) + .collect::, TensorError>>()?; + + Ok(output) +} + +/// Performs a 2D deconvolution on the given input tensor. 
+/// # Examples +/// ``` +// // expected ouputs are taken from pytorch torch.nn.functional.conv_transpose2d +/// +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::deconv; +/// +/// let c = Tensor::::new(Some(&[6, 0, 12, 4, 0, 8, 0, 0, 3, 0, 0, 2]), &[1, 2, 2, 3]).unwrap(); +/// let x = Tensor::::new( +/// Some(&[2, 4, 0, 1]), +/// &[1, 1, 2, 2], +/// ).unwrap(); +/// +/// let result = deconv::(&[x, c], [(1, 1); 2], (1, 1), (2, 2)).unwrap(); +/// let expected = Tensor::::new(Some(&[0, 32, 0, 32, 0, 6, 0, 12, 0, 4, 0, 8, 0, 4, 0, 8, 0, 0, 0, 3, 0, 0, 0, 2]), &[1, 2, 3, 4]).unwrap(); +/// assert_eq!(result, expected); +/// +/// let x = Tensor::::new( +/// Some(&[2, 4, 0, 1]), +/// &[1, 1, 2, 2], +/// ).unwrap(); +/// let k = Tensor::::new( +/// Some(&[3, 1, 1, 5]), +/// &[1, 1, 2, 2], +/// ).unwrap(); +/// let result = deconv::(&[x, k], [(0, 0); 2], (0, 0), (1, 1)).unwrap(); +/// let expected = Tensor::::new(Some(&[6, 14, 4, 2, 17, 21, 0, 1, 5]), &[1, 1, 3, 3]).unwrap(); +/// assert_eq!(result, expected); +/// +/// +/// let x = Tensor::::new( +/// Some(&[2, 4, 0, 1]), +/// &[1, 1, 2, 2], +/// ).unwrap(); +/// let k = Tensor::::new( +/// Some(&[3, 1, 1, 5]), +/// &[1, 1, 2, 2], +/// ).unwrap(); +/// let result = deconv::(&[x, k], [(1, 1); 2], (0, 0), (1, 1)).unwrap(); +/// let expected = Tensor::::new(Some(&[17]), &[1, 1, 1, 1]).unwrap(); +/// assert_eq!(result, expected); +/// +/// +/// let x = Tensor::::new( +/// Some(&[2, 4, 0, 1]), +/// &[1, 1, 2, 2], +/// ).unwrap(); +/// let k = Tensor::::new( +/// Some(&[3, 1, 1, 5]), +/// &[1, 1, 2, 2], +/// ).unwrap(); +/// let result = deconv::(&[x, k], [(1, 1); 2], (0, 0), (2, 2)).unwrap(); +/// let expected = Tensor::::new(Some(&[10, 4, 0, 3]), &[1, 1, 2, 2]).unwrap(); +/// assert_eq!(result, expected); +/// +/// let x = Tensor::::new( +/// Some(&[2, 4, 0, 1]), +/// &[1, 1, 2, 2], +/// ).unwrap(); +/// let k = Tensor::::new( +/// Some(&[3, 1, 1, 5]), +/// &[1, 1, 2, 2], +/// ).unwrap(); +/// let result = 
deconv::(&[x, k], [(0, 0); 2], (0, 0), (2, 2)).unwrap(); +/// let expected = Tensor::::new(Some(&[6, 2, 12, 4, 2, 10, 4, 20, 0, 0, 3, 1, 0, 0, 1, 5]), &[1, 1, 4, 4]).unwrap(); +/// assert_eq!(result, expected); +/// +/// let x = Tensor::::new( +/// Some(&[2, 4, 0, 1]), +/// &[1, 1, 2, 2], +/// ).unwrap(); +/// let k = Tensor::::new( +/// Some(&[3, 2]), +/// &[1, 1, 2, 1], +/// ).unwrap(); +/// let result = deconv::(&[x, k], [(1, 1); 2], (0, 0), (2, 2)).unwrap(); +/// let expected = Tensor::::new(Some(&[0, 0]), &[1, 1, 2, 1]).unwrap(); +/// assert_eq!(result, expected); +/// +/// let x = Tensor::::new( +/// Some(&[2, 4, 0, 1]), +/// &[1, 1, 2, 2], +/// ).unwrap(); +/// let k = Tensor::::new( +/// Some(&[3, 2]), +/// &[1, 1, 2, 1], +/// ).unwrap(); +/// let result = deconv::(&[x, k], [(0, 0); 2], (0, 0), (2, 2)).unwrap(); +/// let expected = Tensor::::new(Some(&[6, 0, 12, 4, 0, 8, 0, 0, 3, 0, 0, 2]), &[1, 1, 4, 3]).unwrap(); +/// assert_eq!(result, expected); +/// +/// +/// let c = Tensor::::new(Some(&[6, 0, 12, 4, 0, 8, 0, 0, 3, 0, 0, 2]), &[1, 2, 2, 3]).unwrap(); +/// let x = Tensor::::new( +/// Some(&[2, 4, 0, 1]), +/// &[1, 1, 2, 2], +/// ).unwrap(); +/// +/// let result = deconv::(&[x, c], [(1, 1); 2], (0, 0), (2, 2)).unwrap(); +/// let expected = Tensor::::new(Some(&[0, 32, 0, 0, 6, 0, 0, 4, 0, 0, 0, 0]), &[1, 2, 2, 3]).unwrap(); +/// assert_eq!(result, expected); +/// +/// +/// let x = Tensor::::new( +/// Some(&[3, 8, 0, 8, 4, 9, 8, 1, 8]), +/// &[1, 1, 3, 3], +/// ).unwrap(); +/// let k = Tensor::::new( +/// Some(&[1, 0, 4, 6]), +/// &[1, 1, 2, 2], +/// ).unwrap(); +/// let b = Tensor::::new( +/// Some(&[1]), +/// &[1], +/// ).unwrap(); +/// let result = deconv::(&[x, k, b], [(1, 1); 2], (0, 0), (1, 1)).unwrap(); +/// let expected = Tensor::::new(Some(&[55, 58, 66, 69]), &[1, 1, 2, 2]).unwrap(); +/// assert_eq!(result, expected); +/// +/// ``` +pub fn deconv< + T: TensorType + + Mul + + Add + + std::marker::Sync + + std::marker::Send + + std::iter::Sum, +>( + 
inputs: &[Tensor], + padding: [(usize, usize); 2], + output_padding: (usize, usize), + stride: (usize, usize), +) -> Result, TensorError> { + let has_bias = inputs.len() == 3; + let (image, kernel) = (&inputs[0], &inputs[1]); + + if (image.dims().len() != 4) || (kernel.dims().len() != 4) { + return Err(TensorError::DimMismatch("deconv".to_string())); + } + + if stride.0 == 0 || stride.1 == 0 { + return Err(TensorError::DimMismatch( + "non-positive stride is not supported for deconv".to_string(), + )); + } + + if has_bias { + let bias = &inputs[2]; + if (bias.dims().len() != 1) || (bias.dims()[0] != kernel.dims()[0]) { + return Err(TensorError::DimMismatch("deconv bias".to_string())); + } + } + + let (kernel_height, kernel_width) = (kernel.dims()[2], kernel.dims()[3]); + + let mut expanded_image = intercalate_values(image, T::zero().unwrap(), stride.0, 2)?; + expanded_image = intercalate_values(&expanded_image, T::zero().unwrap(), stride.1, 3)?; + expanded_image = pad(&expanded_image, [(kernel_height - 1, kernel_width - 1); 2])?; + + // flip order + let channel_coord = (0..kernel.dims()[0]) + .cartesian_product(0..kernel.dims()[1]) + .collect::>(); + + let slice_coord = expanded_image + .dims() + .iter() + .enumerate() + .map(|(i, d)| { + if i == 2 { + padding[0].0..d - padding[1].0 + output_padding.0 + } else if i == 3 { + padding[0].1..d - padding[1].1 + output_padding.1 + } else { + 0..*d + } + }) + .collect::>(); + + let sliced_expanded_image = expanded_image.get_slice(&slice_coord)?; + + let mut inverted_kernels = vec![]; + + for (i, j) in channel_coord { + let mut channel = kernel.get_slice(&[i..i + 1, j..j + 1])?; + channel = Tensor::from(channel.clone().into_iter().rev()); + channel.reshape(&[kernel.dims()[2], kernel.dims()[3]])?; + inverted_kernels.push(channel); + } + + let mut deconv_kernel = + Tensor::new(Some(&inverted_kernels), &[inverted_kernels.len()])?.combine()?; + deconv_kernel.reshape(kernel.dims())?; + + // tensorflow formatting patch + if 
kernel.dims()[0] == sliced_expanded_image.dims()[1] { + deconv_kernel.reshape(&[ + kernel.dims()[1], + kernel.dims()[0], + kernel.dims()[2], + kernel.dims()[3], + ])?; + } + + let input = if has_bias { + vec![ + sliced_expanded_image, + deconv_kernel.clone(), + inputs[2].clone(), + ] + } else { + vec![sliced_expanded_image, deconv_kernel.clone()] + }; + + let output = conv(&input, [(0, 0); 2], (1, 1))?; + + Ok(output) +} + +/// Applies 2D sum pooling over a 4D tensor of shape B x C x H x W. +/// # Arguments +/// +/// * `image` - Tensor. +/// * `padding` - Tuple of padding values in x and y directions. +/// * `stride` - Tuple of stride values in x and y directions. +/// * `pool_dims` - Tuple of pooling window size in x and y directions. +/// # Examples +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::sumpool; +/// use halo2_proofs::circuit::Value; +/// use halo2_proofs::plonk::Assigned; +/// use halo2curves::pasta::Fp as F; +/// +/// let x = Tensor::::new( +/// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]), +/// &[1, 1, 3, 3], +/// ).unwrap(); +/// let pooled = sumpool::(&x, [(0, 0); 2], (1, 1), (2, 2)).unwrap(); +/// let expected: Tensor = Tensor::::new(Some(&[11, 8, 8, 10]), &[1, 1, 2, 2]).unwrap(); +/// assert_eq!(pooled, expected); +/// ``` +pub fn sumpool< + T: TensorType + Mul + Add + std::marker::Sync + std::marker::Send, +>( + image: &Tensor, + padding: [(usize, usize); 2], + stride: (usize, usize), + kernel_shape: (usize, usize), +) -> Result, TensorError> { + if image.dims().len() != 4 { + return Err(TensorError::DimMismatch("sumpool".to_string())); + } + let image_dims = image.dims(); + + let (batch, image_channels, image_height, image_width) = + (image_dims[0], image_dims[1], image_dims[2], image_dims[3]); + + let (output_channels, kernel_height, kernel_width) = + (image_channels, kernel_shape.0, kernel_shape.1); + + let padded_image = pad::(image, padding)?; + + let vert_slides = (image_height + padding[0].0 + padding[1].0 - kernel_height) / 
stride.0 + 1; + let horz_slides = (image_width + padding[0].1 + padding[1].1 - kernel_width) / stride.1 + 1; + + // calculate value of output + let mut output: Tensor = + Tensor::new(None, &[batch, output_channels, vert_slides, horz_slides]).unwrap(); + + let cartesian_coord = [ + (0..batch), + (0..output_channels), + (0..vert_slides), + (0..horz_slides), + ] + .iter() + .cloned() + .multi_cartesian_product() + .collect::>(); + + output + .par_iter_mut() + .enumerate() + .for_each(|(flat_index, o)| { + let coord = &cartesian_coord[flat_index]; + let (b, i, j, k) = (coord[0], coord[1], coord[2], coord[3]); + let rs = j * stride.0; + let cs = k * stride.1; + let thesum = sum(&padded_image + .get_slice(&[ + b..b + 1, + i..i + 1, + rs..(rs + kernel_height), + cs..(cs + kernel_width), + ]) + .unwrap()) + .unwrap(); + *o = thesum[0].clone(); + }); + + Ok(output) +} + +/// Applies 2D max pooling over a 4D tensor of shape B x C x H x W. +/// # Arguments +/// +/// * `image` - Tensor. +/// * `padding` - Tuple of padding values in x and y directions. +/// * `stride` - Tuple of stride values in x and y directions. +/// * `pool_dims` - Tuple of pooling window size in x and y directions. 
+/// # Examples +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::max_pool2d; +/// use ezkl::circuit::utils::F32; +/// use halo2_proofs::circuit::Value; +/// use halo2_proofs::plonk::Assigned; +/// use halo2curves::pasta::Fp as F; +/// +/// +/// let x = Tensor::::new( +/// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]), +/// &[1, 1, 3, 3], +/// ).unwrap(); +/// let pooled = max_pool2d::(&x, &[(0, 0); 2], &(1, 1), &(2, 2)).unwrap(); +/// let expected: Tensor = Tensor::::new(Some(&[5, 4, 4, 6]), &[1, 1, 2, 2]).unwrap(); +/// assert_eq!(pooled, expected); +/// +/// let x = Tensor::::new(Some(&[-0.9180, -0.4702, -0.0882, -0.0885, 0.3940, +/// -0.4884, 0.1395, 1.7860, -0.9729, 1.5160, -0.3346, +/// -0.0601, -0.1140, 0.2522, -0.2938, -0.0355]), &[1,1,4,4]).unwrap(); +/// let x = x.map(|x| F32(x)); +/// let pooled = max_pool2d::(&x, &[(0, 0); 2], &(2, 2), &(2, 2)).unwrap(); +/// let expected = Tensor::::new(Some(&[0.3940, 1.7860, 1.5160, -0.0355]), &[1, 1, 2, 2]).unwrap(); +/// let expected = expected.map(|x| F32(x)); +/// assert_eq!(pooled, expected); +/// ``` +pub fn max_pool2d( + image: &Tensor, + padding: &[(usize, usize); 2], + stride: &(usize, usize), + pool_dims: &(usize, usize), +) -> Result, TensorError> { + if image.dims().len() != 4 { + return Err(TensorError::DimMismatch("max_pool2d".to_string())); + } + let image_dims = image.dims(); + + let (batch, input_channels, image_height, image_width) = + (image_dims[0], image_dims[1], image_dims[2], image_dims[3]); + + let padded_image = pad::(image, *padding)?; + + let vert_slides = (image_height + padding[0].0 + padding[1].0 - pool_dims.0) / stride.0 + 1; + let horz_slides = (image_width + padding[0].1 + padding[1].1 - pool_dims.1) / stride.1 + 1; + + let mut output: Tensor = + Tensor::new(None, &[batch, input_channels, horz_slides, vert_slides]).unwrap(); + + let cartesian_coord = [ + (0..batch), + (0..input_channels), + (0..vert_slides), + (0..horz_slides), + ] + .iter() + .cloned() + 
.multi_cartesian_product() + .collect::>(); + + output + .par_iter_mut() + .enumerate() + .for_each(|(flat_index, o)| { + let coord = &cartesian_coord[flat_index]; + let (b, i, j, k) = (coord[0], coord[1], coord[2], coord[3]); + let rs = j * stride.0; + let cs = k * stride.1; + let themax = padded_image + .get_slice(&[ + b..(b + 1), + i..(i + 1), + rs..(rs + pool_dims.0), + cs..(cs + pool_dims.1), + ]) + .unwrap() + .into_iter() + .max() + .unwrap(); + *o = themax; + }); + + Ok(output) +} + +/// Dot product of two tensors. +/// # Arguments +/// +/// * `inputs` - Vector of tensors of length 2. +/// # Examples +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::dot; +/// +/// let x = Tensor::::new( +/// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]), +/// &[1, 3, 3], +/// ).unwrap(); +/// let y = Tensor::::new( +/// Some(&[5, 5, 10, -4, 2, -1, 2, 0, 1]), +/// &[1, 3, 3], +/// ).unwrap(); +/// assert_eq!(dot(&[x, y]).unwrap()[0], 86); +/// ``` +pub fn dot + Add + Send + Sync + std::iter::Sum>( + inputs: &[Tensor], +) -> Result, TensorError> { + if (inputs.len() != 2) || (inputs[0].clone().len() != inputs[1].clone().len()) { + return Err(TensorError::DimMismatch("dot".to_string())); + } + + let (a, b): (Tensor, Tensor) = (inputs[0].clone(), inputs[1].clone()); + let res: Vec = a + .par_iter() + .zip(b.par_iter()) + .fold( + || T::zero().unwrap(), + |acc, (k, i)| acc + k.clone() * i.clone(), + ) + .collect(); + + let res = res.into_iter().sum(); + + Tensor::new(Some(&[res]), &[1]) +} + +/// Pads a 4D tensor of shape `B x C x H x W` to a tensor of shape `B x C x (H + 2xPADDING) x (W + 2xPADDING)` using 0 values. +/// # Arguments +/// +/// * `image` - Tensor. +/// * `padding` - Tuple of padding values in x and y directions. 
+/// # Examples +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::pad; +/// +/// let x = Tensor::::new( +/// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]), +/// &[1, 1, 3, 3], +/// ).unwrap(); +/// let result = pad::(&x, [(1, 1); 2]).unwrap(); +/// let expected = Tensor::::new( +/// Some(&[0, 0, 0, 0, 0, 0, 5, 2, 3, 0, 0, 0, 4, -1, 0, 0, 3, 1, 6, 0, 0, 0, 0, 0, 0]), +/// &[1, 1, 5, 5], +/// ).unwrap(); +/// assert_eq!(result, expected); +/// +/// +/// +/// ``` +pub fn pad( + image: &Tensor, + padding: [(usize, usize); 2], +) -> Result, TensorError> { + if image.dims().len() != 4 { + return Err(TensorError::DimMismatch("pad".to_string())); + } + let (batch_size, channels, height, width) = ( + image.dims()[0], + image.dims()[1], + image.dims()[2], + image.dims()[3], + ); + + let (padding_before, padding_after) = padding.into(); + + let padded_height = height + padding_before.0 + padding_after.0; + let padded_width = width + padding_before.1 + padding_after.1; + + let mut output = + Tensor::::new(None, &[batch_size, channels, padded_height, padded_width]).unwrap(); + + for b in 0..batch_size { + for channel in 0..channels { + for row in 0..height { + for col in 0..width { + output.set( + &[b, channel, row + padding_before.0, col + padding_before.1], + image.get(&[b, channel, row, col]).clone(), + ); + } + } + } + } + + output.reshape(&[batch_size, channels, padded_height, padded_width])?; + Ok(output) +} + +/// Packs a multi-dim tensor into a single elem tensor +/// # Arguments +/// +/// * `a` - Tensor. 
+/// * `base` - Base to use when packing +/// * `scale` - fixed point representation scale +/// # Examples +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::pack; +/// +/// let x = Tensor::::new( +/// Some(&[5, 2, 1]), +/// &[1, 3], +/// ).unwrap(); +/// let result = pack::(&x, 2, 2).unwrap(); +/// let expected = Tensor::::new( +/// Some(&[90]), +/// &[1], +/// ).unwrap(); +/// assert_eq!(result, expected); +/// ``` +pub fn pack( + a: &Tensor, + base: T, + scale: u32, +) -> Result, TensorError> +where + T: Add, + T: Mul, +{ + // base ^ (scale + tensor) + let mut output = T::zero().unwrap(); + let base_tensor = Tensor::new(Some(&[base]), &[1])?; + for (i, a_i) in a.iter().enumerate() { + let pow_value = &base_tensor.pow((i as u32) * (scale + 1))?[0]; + output = output + pow_value.clone() * a_i.clone(); + } + Tensor::new(Some(&[output]), &[1]) +} + +/// Concatenates a list of tensors along a specified axis. +/// # Arguments +/// * `inputs` - A slice of tensors to concatenate. +/// * `axis` - The axis along which to concatenate the tensors. 
/// # Errors
/// Returns a `TensorError` if the tensors in `inputs` have incompatible
/// dimensions for concatenation along the specified `axis`.
pub fn concat<T: TensorType + Send + Sync>(
    inputs: &[&Tensor<T>],
    axis: usize,
) -> Result<Tensor<T>, TensorError> {
    if inputs.len() == 1 {
        return Ok(inputs[0].clone());
    }

    // Output shape: same as the first input except along `axis`, which is the
    // sum of all the inputs' extents on that axis.
    let mut output_size = inputs[0].dims().to_vec();
    output_size[axis] = inputs.iter().map(|x| x.dims()[axis]).sum();

    let mut output = Tensor::new(None, &output_size)?;
    let cartesian_coord = output_size
        .iter()
        .map(|x| 0..*x)
        .multi_cartesian_product()
        .collect::<Vec<_>>();

    // Maps an output coordinate along `axis` to (which input tensor, the
    // local coordinate along `axis` inside that input).
    let get_input_index = |index_along_axis: usize| -> (usize, usize) {
        let mut current_idx = 0;
        let mut input_idx = 0;
        let mut input_coord_at_idx = 0;
        for (i, elem) in inputs.iter().enumerate() {
            current_idx += elem.dims()[axis];
            if index_along_axis < current_idx {
                input_idx = i;
                // local offset = global coord minus extents of all preceding inputs
                input_coord_at_idx = index_along_axis - (current_idx - elem.dims()[axis]);
                break;
            }
        }
        (input_idx, input_coord_at_idx)
    };

    output = output.par_enum_map(|flat, _: T| {
        let mut input_coord = cartesian_coord[flat].clone();
        // Only the `axis` component differs between output and input
        // coordinates; remap it. (The original loop also kept a running
        // `index` accumulator that was never read — removed as dead code.)
        let (input_index, local) = get_input_index(input_coord[axis]);
        input_coord[axis] = local;

        Ok(inputs[input_index].get(&input_coord))
    })?;

    output.reshape(&output_size)?;

    Ok(output)
}

/// Slices a tensor from start to end along a given axis.
/// # Arguments
/// * `t` - Tensor to slice.
/// * `axis` - Axis to slice along.
/// * `start` - First index (inclusive) along `axis`.
/// * `end` - End index (exclusive) along `axis`.
pub fn slice<T: TensorType + Send + Sync>(
    t: &Tensor<T>,
    axis: &usize,
    start: &usize,
    end: &usize,
) -> Result<Tensor<T>, TensorError> {
    // Full range on every axis except the sliced one.
    let slice = t
        .dims()
        .iter()
        .enumerate()
        .map(|(i, d)| if i == *axis { *start..*end } else { 0..*d })
        .collect::<Vec<_>>();

    t.get_slice(&slice)
}

// ---------------------------------------------------------------------------------------------------------
// -- nonlinear Functions ---------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------------------------

/// Activation functions
pub mod nonlinearities {
    use super::*;

    /// Ceiling operator.
+ /// # Arguments + /// * `a` - Tensor + /// * `scale` - Single value + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// + /// use ezkl::tensor::ops::nonlinearities::ceil; + /// let x = Tensor::::new( + /// Some(&[1, 2, 3, 4, 5, 6]), + /// &[3, 2], + /// ).unwrap(); + /// let result = ceil(&x, 2.0); + /// let expected = Tensor::::new(Some(&[1, 1, 2, 2, 3, 3]), &[3, 2]).unwrap(); + /// assert_eq!(result, expected); + /// ``` + pub fn ceil(a: &Tensor, scale: f64) -> Tensor { + a.par_enum_map(|_, a_i| { + let kix = (a_i as f64) / scale; + let rounded = kix.ceil(); + Ok::<_, TensorError>(rounded as i128) + }) + .unwrap() + } + + /// Floor operator. + /// # Arguments + /// * `a` - Tensor + /// * `scale` - Single value + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use ezkl::tensor::ops::nonlinearities::floor; + /// let x = Tensor::::new( + /// Some(&[1, 2, 3, 4, 5, 6]), + /// &[3, 2], + /// ).unwrap(); + /// let result = floor(&x, 2.0); + /// let expected = Tensor::::new(Some(&[0, 1, 1, 2, 2, 3]), &[3, 2]).unwrap(); + /// assert_eq!(result, expected); + /// ``` + pub fn floor(a: &Tensor, scale: f64) -> Tensor { + a.par_enum_map(|_, a_i| { + let kix = (a_i as f64) / scale; + let rounded = kix.floor(); + Ok::<_, TensorError>(rounded as i128) + }) + .unwrap() + } + + /// Round operator. 
+ /// # Arguments + /// * `a` - Tensor + /// * `scale` - Single value + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use ezkl::tensor::ops::nonlinearities::round; + /// let x = Tensor::::new( + /// Some(&[1, 2, 3, 4, 5, 6]), + /// &[3, 2], + /// ).unwrap(); + /// let result = round(&x, 2.0); + /// let expected = Tensor::::new(Some(&[1, 1, 2, 2, 3, 3]), &[3, 2]).unwrap(); + /// assert_eq!(result, expected); + /// ``` + pub fn round(a: &Tensor, scale: f64) -> Tensor { + a.par_enum_map(|_, a_i| { + let kix = (a_i as f64) / scale; + let rounded = kix.round(); + Ok::<_, TensorError>(rounded as i128) + }) + .unwrap() + } + + /// Round half to even operator. + /// # Arguments + /// * `a` - Tensor + /// * `scale` - Single value + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use ezkl::tensor::ops::nonlinearities::round_half_to_even; + /// let x = Tensor::::new( + /// Some(&[1, 2, 3, 4, 5, 6]), + /// &[3, 2], + /// ).unwrap(); + /// let result = round_half_to_even(&x, 2.0); + /// let expected = Tensor::::new(Some(&[0, 1, 2, 2, 2, 3]), &[3, 2]).unwrap(); + /// assert_eq!(result, expected); + /// ``` + pub fn round_half_to_even(a: &Tensor, scale: f64) -> Tensor { + a.par_enum_map(|_, a_i| { + let kix = (a_i as f64) / scale; + let rounded = kix.round_ties_even(); + Ok::<_, TensorError>(rounded as i128) + }) + .unwrap() + } + + /// Raises to a floating point power. 
+ /// # Arguments + /// * `a` - Tensor + /// * `power` - Floating point power + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use ezkl::tensor::ops::nonlinearities::pow; + /// let x = Tensor::::new( + /// Some(&[2, 15, 2, 1, 1, 0]), + /// &[2, 3], + /// ).unwrap(); + /// let result = pow(&x, 1.0, 2.0); + /// let expected = Tensor::::new(Some(&[4, 225, 4, 1, 1, 0]), &[2, 3]).unwrap(); + /// assert_eq!(result, expected); + /// ``` + pub fn pow(a: &Tensor, scale_input: f64, power: f64) -> Tensor { + a.par_enum_map(|_, a_i| { + let kix = (a_i as f64) / scale_input; + let kix = scale_input * (kix).powf(power); + let rounded = kix.round(); + Ok::<_, TensorError>(rounded as i128) + }) + .unwrap() + } + + /// Applies Kronecker delta to a tensor of integers. + /// # Arguments + /// * `a` - Tensor + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use ezkl::tensor::ops::nonlinearities::kronecker_delta; + /// let x = Tensor::::new( + /// Some(&[2, 15, 2, 1, 1, 0]), + /// &[2, 3], + /// ).unwrap(); + /// let result = kronecker_delta(&x); + /// let expected = Tensor::::new(Some(&[0, 0, 0, 0, 0, 1]), &[2, 3]).unwrap(); + /// assert_eq!(result, expected); + /// ``` + pub fn kronecker_delta( + a: &Tensor, + ) -> Tensor { + a.par_enum_map(|_, a_i| { + if a_i == T::zero().unwrap() { + Ok::<_, TensorError>(T::one().unwrap()) + } else { + Ok::<_, TensorError>(T::zero().unwrap()) + } + }) + .unwrap() + } + + /// Elementwise applies sigmoid to a tensor of integers. 
+ /// # Arguments + /// + /// * `a` - Tensor + /// * `scale_input` - Single value + /// * `scale_output` - Single value + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use ezkl::tensor::ops::nonlinearities::sigmoid; + /// let x = Tensor::::new( + /// Some(&[2, 15, 2, 1, 1, 0]), + /// &[2, 3], + /// ).unwrap(); + /// let result = sigmoid(&x, 1.0); + /// let expected = Tensor::::new(Some(&[1, 1, 1, 1, 1, 1]), &[2, 3]).unwrap(); + /// + /// assert_eq!(result, expected); + /// let x = Tensor::::new( + /// Some(&[65536]), + /// &[1], + /// ).unwrap(); + /// let result = sigmoid(&x, 65536.0); + /// let expected = Tensor::::new(Some(&[47911]), &[1]).unwrap(); + /// assert_eq!(result, expected); + /// + /// /// assert_eq!(result, expected); + /// let x = Tensor::::new( + /// Some(&[256]), + /// &[1], + /// ).unwrap(); + /// let result = sigmoid(&x, 256.0); + /// let expected = Tensor::::new(Some(&[187]), &[1]).unwrap(); + /// + /// ``` + pub fn sigmoid(a: &Tensor, scale_input: f64) -> Tensor { + a.par_enum_map(|_, a_i| { + let kix = (a_i as f64) / scale_input; + let fout = scale_input / (1.0 + (-kix).exp()); + let rounded = fout.round(); + Ok::<_, TensorError>(rounded as i128) + }) + .unwrap() + } + + /// Elementwise applies exponential to a tensor of integers. 
+ /// # Arguments + /// + /// * `a` - Tensor + /// * `scale_input` - Single value + /// * `scale_output` - Single value + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use ezkl::tensor::ops::nonlinearities::exp; + /// let x = Tensor::::new( + /// Some(&[2, 15, 2, 1, 1, 0]), + /// &[2, 3], + /// ).unwrap(); + /// let result = exp(&x, 1.0); + /// let expected = Tensor::::new(Some(&[7, 3269017, 7, 3, 3, 1]), &[2, 3]).unwrap(); + /// assert_eq!(result, expected); + /// + /// + /// let x = Tensor::::new( + /// Some(&[37, 12, 41]), + /// &[3], + /// ).unwrap(); + /// let result = exp(&x, 512.0); + /// + /// let expected = Tensor::::new(Some(&[550, 524, 555]), &[3]).unwrap(); + /// + /// assert_eq!(result, expected); + /// ``` + pub fn exp(a: &Tensor, scale_input: f64) -> Tensor { + a.par_enum_map(|_, a_i| { + let kix = (a_i as f64) / scale_input; + let fout = scale_input * kix.exp(); + let rounded = fout.round(); + Ok::<_, TensorError>(rounded as i128) + }) + .unwrap() + } + + /// Elementwise applies exponential to a tensor of integers. 
+ /// # Arguments + /// + /// * `a` - Tensor + /// * `scale_input` - Single value + /// * `scale_output` - Single value + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use ezkl::tensor::ops::nonlinearities::ln; + /// let x = Tensor::::new( + /// Some(&[2, 15, 2, 1, 1, 3000]), + /// &[2, 3], + /// ).unwrap(); + /// let result = ln(&x, 1.0); + /// let expected = Tensor::::new(Some(&[1, 3, 1, 0, 0, 8]), &[2, 3]).unwrap(); + /// assert_eq!(result, expected); + /// + /// + /// let x = Tensor::::new( + /// Some(&[37, 12, 41]), + /// &[3], + /// ).unwrap(); + /// let result = ln(&x, 512.0); + /// + /// let expected = Tensor::::new(Some(&[-1345, -1922, -1293]), &[3]).unwrap(); + /// + /// assert_eq!(result, expected); + /// ``` + pub fn ln(a: &Tensor, scale_input: f64) -> Tensor { + a.par_enum_map(|_, a_i| { + let kix = (a_i as f64) / scale_input; + let fout = scale_input * kix.ln(); + let rounded = fout.round(); + Ok::<_, TensorError>(rounded as i128) + }) + .unwrap() + } + + /// Elementwise applies sign to a tensor of integers. 
+ /// # Arguments + /// * `a` - Tensor + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use ezkl::tensor::ops::nonlinearities::sign; + /// let x = Tensor::::new( + /// Some(&[-2, 15, 2, 1, 1, 0]), + /// &[2, 3], + /// ).unwrap(); + /// let result = sign(&x); + /// let expected = Tensor::::new(Some(&[-1, 1, 1, 1, 1, 0]), &[2, 3]).unwrap(); + /// assert_eq!(result, expected); + /// ``` + pub fn sign(a: &Tensor) -> Tensor { + a.par_enum_map(|_, a_i| Ok::<_, TensorError>(a_i.signum())) + .unwrap() + } + + /// softmax layout + pub fn softmax_axes( + a: &Tensor, + scale: f64, + axes: &[usize], + ) -> (Tensor, Vec>) { + // we want this to be as small as possible so we set the output scale to 1 + let dims = a.dims(); + + if dims.len() == 1 { + return softmax(a, scale); + } + + let mut intermediate_values = vec![]; + + let cartesian_coord = dims[..dims.len() - 1] + .iter() + .map(|x| 0..*x) + .multi_cartesian_product() + .collect::>(); + + let mut outputs = vec![]; + + for coord in cartesian_coord { + let mut sum_dims = vec![]; + for (i, c) in coord.iter().enumerate() { + if axes.contains(&i) { + sum_dims.push(0..a.dims()[i]); + } else { + sum_dims.push(*c..*c + 1); + } + } + + let softmax_input = a.get_slice(&sum_dims).unwrap(); + + let res = softmax(&softmax_input, scale); + + outputs.push(res.0); + intermediate_values.extend(res.1); + } + + let mut res = Tensor::new(Some(&outputs), &[outputs.len()]) + .unwrap() + .combine() + .unwrap(); + res.reshape(dims).unwrap(); + (res, intermediate_values) + } + + /// Applies softmax + /// # Arguments + /// + /// * `a` - Tensor + /// * `scale_input` - Single value + /// * `scale_output` - Single value + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use ezkl::tensor::ops::nonlinearities::softmax; + /// let x = Tensor::::new( + /// Some(&[2, 2, 3, 2, 2, 0]), + /// &[2, 3], + /// ).unwrap(); + /// let result = softmax(&x, 128.0).0; + /// // doubles the scale of the input + /// let expected = 
Tensor::::new(Some(&[2730, 2730, 2751, 2730, 2730, 2688]), &[2, 3]).unwrap(); + /// assert_eq!(result, expected); + /// ``` + pub fn softmax(a: &Tensor, scale: f64) -> (Tensor, Vec>) { + // the more accurate calculation is commented out and we implement as below so it matches the steps in layout + let mut intermediate_values = vec![]; + + intermediate_values.push(a.clone()); + + let exp = exp(a, scale); + + let sum = sum(&exp).unwrap(); + intermediate_values.push(sum.clone()); + let inv_denom = recip(&sum, scale.powf(2.0)); + + ((exp * inv_denom).unwrap(), intermediate_values) + } + + /// Applies range_check_percent + /// # Arguments + /// + /// * `a` - Tensor + /// * `b` - Tensor + /// * `scale_input` - Single value + /// * `scale_output` - Single value + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use ezkl::tensor::ops::nonlinearities::range_check_percent; + /// let x = Tensor::::new( + /// Some(&[100, 200, 300, 400, 500, 600]), + /// &[2, 3], + /// ).unwrap(); + /// let y = Tensor::::new( + /// Some(&[103, 204, 303, 404, 505, 607]), + /// &[2, 3], + /// ).unwrap(); + /// let result = range_check_percent(&[x, y], 1024, 1024, 1.0); // 1% tolerance + /// let expected = Tensor::::new(Some(&[1, 1, 0, 0, 0, 1]), &[2, 3]).unwrap(); + /// assert_eq!(result, expected); + /// ``` + pub fn range_check_percent( + t: &[Tensor], + input_scale: usize, + output_scale: usize, + tol: f32, + ) -> Tensor { + // the more accurate calculation is commented out and we implement as below so it matches the steps in layout + let scale = input_scale * output_scale; + let diff: Tensor = sub(t).unwrap(); + let recip = recip(&t[0], scale as f64); + let product = mult(&[diff, recip]).unwrap(); + let _tol = ((tol / 100.0) * scale as f32).round() as f64; + let upper_bound = greater_than(&product, _tol); + let neg_product = + mult(&[product, Tensor::::new(Some(&[-1]), &[1]).unwrap()]).unwrap(); + let lower_bound = greater_than(&neg_product, _tol); + + add(&[upper_bound, 
lower_bound]).unwrap() + } + + /// Elementwise applies square root to a tensor of integers. + /// # Arguments + /// + /// * `a` - Tensor + /// * `scale_input` - Single value + /// * `scale_output` - Single value + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use ezkl::tensor::ops::nonlinearities::sqrt; + /// let x = Tensor::::new( + /// Some(&[4, 25, 8, 1, 1, 0]), + /// &[2, 3], + /// ).unwrap(); + /// let result = sqrt(&x, 1.0); + /// let expected = Tensor::::new(Some(&[2, 5, 3, 1, 1, 0]), &[2, 3]).unwrap(); + /// assert_eq!(result, expected); + /// ``` + pub fn sqrt(a: &Tensor, scale_input: f64) -> Tensor { + a.par_enum_map(|_, a_i| { + let kix = (a_i as f64) / scale_input; + let fout = scale_input * kix.sqrt(); + let rounded = fout.round(); + Ok::<_, TensorError>(rounded as i128) + }) + .unwrap() + } + + /// Elementwise applies reciprocal square root to a tensor of integers. + /// # Arguments + /// + /// * `a` - Tensor + /// * `scale_input` - Single value + /// * `scale_output` - Single value + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use ezkl::tensor::ops::nonlinearities::rsqrt; + /// let x = Tensor::::new( + /// Some(&[4, 25, 8, 1, 1, 1]), + /// &[2, 3], + /// ).unwrap(); + /// let result = rsqrt(&x, 1.0); + /// let expected = Tensor::::new(Some(&[1, 0, 0, 1, 1, 1]), &[2, 3]).unwrap(); + /// assert_eq!(result, expected); + /// ``` + pub fn rsqrt(a: &Tensor, scale_input: f64) -> Tensor { + a.par_enum_map(|_, a_i| { + let kix = (a_i as f64) / scale_input; + let fout = scale_input / (kix.sqrt() + f64::EPSILON); + let rounded = fout.round(); + Ok::<_, TensorError>(rounded as i128) + }) + .unwrap() + } + + /// Elementwise applies cosine to a tensor of integers. 
+ /// # Arguments + /// * `a` - Tensor + /// * `scale_input` - Single value + /// * `scale_output` - Single value + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use ezkl::tensor::ops::nonlinearities::cos; + /// let x = Tensor::::new( + /// Some(&[4, 25, 8, 1, 1, 0]), + /// &[2, 3], + /// ).unwrap(); + /// let result = cos(&x, 2.0); + /// let expected = Tensor::::new(Some(& [-1, 2, -1, 2, 2, 2]), &[2, 3]).unwrap(); + /// assert_eq!(result, expected); + /// ``` + pub fn cos(a: &Tensor, scale_input: f64) -> Tensor { + a.par_enum_map(|_, a_i| { + let kix = (a_i as f64) / scale_input; + let fout = scale_input * kix.cos(); + let rounded = fout.round(); + Ok::<_, TensorError>(rounded as i128) + }) + .unwrap() + } + + /// Elementwise applies arccosine to a tensor of integers. + /// # Arguments + /// * `a` - Tensor + /// * `scale_input` - Single value + /// * `scale_output` - Single value + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use ezkl::tensor::ops::nonlinearities::acos; + /// let x = Tensor::::new( + /// Some(&[4, 25, 8, 1, 1, 0]), + /// &[2, 3], + /// ).unwrap(); + /// let result = acos(&x, 1.0); + /// let expected = Tensor::::new(Some(&[0, 0, 0, 0, 0, 2]), &[2, 3]).unwrap(); + /// assert_eq!(result, expected); + /// ``` + pub fn acos(a: &Tensor, scale_input: f64) -> Tensor { + a.par_enum_map(|_, a_i| { + let kix = (a_i as f64) / scale_input; + let fout = scale_input * kix.acos(); + let rounded = fout.round(); + Ok::<_, TensorError>(rounded as i128) + }) + .unwrap() + } + + /// Elementwise applies cosh to a tensor of integers. 
+ /// # Arguments + /// * `a` - Tensor + /// * `scale_input` - Single value + /// * `scale_output` - Single value + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use ezkl::tensor::ops::nonlinearities::cosh; + /// let x = Tensor::::new( + /// Some(&[4, 25, 8, 1, 1, 0]), + /// &[2, 3], + /// ).unwrap(); + /// let result = cosh(&x, 1.0); + /// let expected = Tensor::::new(Some(&[27, 36002449669, 1490, 2, 2, 1]), &[2, 3]).unwrap(); + /// assert_eq!(result, expected); + /// ``` + pub fn cosh(a: &Tensor, scale_input: f64) -> Tensor { + a.par_enum_map(|_, a_i| { + let kix = (a_i as f64) / scale_input; + let fout = scale_input * kix.cosh(); + let rounded = fout.round(); + Ok::<_, TensorError>(rounded as i128) + }) + .unwrap() + } + + /// Elementwise applies arccosineh to a tensor of integers. + /// # Arguments + /// * `a` - Tensor + /// * `scale_input` - Single value + /// * `scale_output` - Single value + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use ezkl::tensor::ops::nonlinearities::acosh; + /// let x = Tensor::::new( + /// Some(&[4, 25, 8, 1, 1, 0]), + /// &[2, 3], + /// ).unwrap(); + /// let result = acosh(&x, 1.0); + /// let expected = Tensor::::new(Some(& [2, 4, 3, 0, 0, 0]), &[2, 3]).unwrap(); + /// assert_eq!(result, expected); + /// ``` + pub fn acosh(a: &Tensor, scale_input: f64) -> Tensor { + a.par_enum_map(|_, a_i| { + let kix = (a_i as f64) / scale_input; + let fout = scale_input * kix.acosh(); + let rounded = fout.round(); + Ok::<_, TensorError>(rounded as i128) + }) + .unwrap() + } + + /// Elementwise applies sine to a tensor of integers. 
+ /// # Arguments + /// * `a` - Tensor + /// * `scale_input` - Single value + /// * `scale_output` - Single value + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use ezkl::tensor::ops::nonlinearities::sin; + /// let x = Tensor::::new( + /// Some(&[4, 25, 8, 1, 1, 0]), + /// &[2, 3], + /// ).unwrap(); + /// let result = sin(&x, 128.0); + /// let expected = Tensor::::new(Some(&[4, 25, 8, 1, 1, 0]), &[2, 3]).unwrap(); + /// assert_eq!(result, expected); + /// ``` + pub fn sin(a: &Tensor, scale_input: f64) -> Tensor { + a.par_enum_map(|_, a_i| { + let kix = (a_i as f64) / scale_input; + let fout = scale_input * kix.sin(); + let rounded = fout.round(); + Ok::<_, TensorError>(rounded as i128) + }) + .unwrap() + } + + /// Elementwise applies arcsine to a tensor of integers. + /// # Arguments + /// * `a` - Tensor + /// * `scale_input` - Single value + /// * `scale_output` - Single value + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use ezkl::tensor::ops::nonlinearities::asin; + /// let x = Tensor::::new( + /// Some(&[4, 25, 8, 1, 1, 0]), + /// &[2, 3], + /// ).unwrap(); + /// let result = asin(&x, 128.0); + /// let expected = Tensor::::new(Some(& [4, 25, 8, 1, 1, 0]), &[2, 3]).unwrap(); + /// assert_eq!(result, expected); + /// ``` + pub fn asin(a: &Tensor, scale_input: f64) -> Tensor { + a.par_enum_map(|_, a_i| { + let kix = (a_i as f64) / scale_input; + let fout = scale_input * kix.asin(); + let rounded = fout.round(); + Ok::<_, TensorError>(rounded as i128) + }) + .unwrap() + } + + /// Elementwise applies sineh to a tensor of integers. 
+ /// # Arguments + /// * `a` - Tensor + /// * `scale_input` - Single value + /// * `scale_output` - Single value + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use ezkl::tensor::ops::nonlinearities::sinh; + /// let x = Tensor::::new( + /// Some(&[4, 25, 8, 1, 1, 0]), + /// &[2, 3], + /// ).unwrap(); + /// let result = sinh(&x, 2.0); + /// let expected = Tensor::::new(Some(&[7, 268337, 55, 1, 1, 0]), &[2, 3]).unwrap(); + /// assert_eq!(result, expected); + /// ``` + pub fn sinh(a: &Tensor, scale_input: f64) -> Tensor { + a.par_enum_map(|_, a_i| { + let kix = (a_i as f64) / scale_input; + let fout = scale_input * kix.sinh(); + let rounded = fout.round(); + Ok::<_, TensorError>(rounded as i128) + }) + .unwrap() + } + + /// Elementwise applies arcsineh to a tensor of integers. + /// # Arguments + /// * `a` - Tensor + /// * `scale_input` - Single value + /// * `scale_output` - Single value + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use ezkl::tensor::ops::nonlinearities::asinh; + /// let x = Tensor::::new( + /// Some(&[4, 25, 8, 1, 1, 0]), + /// &[2, 3], + /// ).unwrap(); + /// let result = asinh(&x, 128.0); + /// let expected = Tensor::::new(Some(&[4, 25, 8, 1, 1, 0]), &[2, 3]).unwrap(); + /// assert_eq!(result, expected); + /// ``` + pub fn asinh(a: &Tensor, scale_input: f64) -> Tensor { + a.par_enum_map(|_, a_i| { + let kix = (a_i as f64) / scale_input; + let fout = scale_input * kix.asinh(); + let rounded = fout.round(); + Ok::<_, TensorError>(rounded as i128) + }) + .unwrap() + } + + /// Elementwise applies tan activation to a tensor of integers. 
+ /// # Arguments + /// * `a` - Tensor + /// * `scale_input` - Single value + /// * `scale_output` - Single value + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use ezkl::tensor::ops::nonlinearities::tan; + /// let x = Tensor::::new( + /// Some(&[4, 25, 8, 1, 1, 0]), + /// &[2, 3], + /// ).unwrap(); + /// let result = tan(&x, 64.0); + /// let expected = Tensor::::new(Some(&[4, 26, 8, 1, 1, 0]), &[2, 3]).unwrap(); + /// assert_eq!(result, expected); + /// ``` + pub fn tan(a: &Tensor, scale_input: f64) -> Tensor { + a.par_enum_map(|_, a_i| { + let kix = (a_i as f64) / scale_input; + let fout = scale_input * kix.tan(); + let rounded = fout.round(); + Ok::<_, TensorError>(rounded as i128) + }) + .unwrap() + } + + /// Elementwise applies arctan activation to a tensor of integers. + /// # Arguments + /// * `a` - Tensor + /// * `scale_input` - Single value + /// * `scale_output` - Single value + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use ezkl::tensor::ops::nonlinearities::atan; + /// let x = Tensor::::new( + /// Some(&[4, 25, 8, 1, 1, 0]), + /// &[2, 3], + /// ).unwrap(); + /// let result = atan(&x, 128.0); + /// let expected = Tensor::::new(Some(&[4, 25, 8, 1, 1, 0]), &[2, 3]).unwrap(); + /// assert_eq!(result, expected); + /// ``` + pub fn atan(a: &Tensor, scale_input: f64) -> Tensor { + a.par_enum_map(|_, a_i| { + let kix = (a_i as f64) / scale_input; + let fout = scale_input * kix.atan(); + let rounded = fout.round(); + Ok::<_, TensorError>(rounded as i128) + }) + .unwrap() + } + + /// Elementwise applies tanh activation to a tensor of integers. 
+ /// # Arguments + /// + /// * `a` - Tensor + /// * `scale_input` - Single value + /// * `scale_output` - Single value + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use ezkl::tensor::ops::nonlinearities::tanh; + /// let x = Tensor::::new( + /// Some(&[4, 25, 8, 1, 1, 0]), + /// &[2, 3], + /// ).unwrap(); + /// let result = tanh(&x, 128.0); + /// let expected = Tensor::::new(Some(&[4, 25, 8, 1, 1, 0]), &[2, 3]).unwrap(); + /// assert_eq!(result, expected); + /// ``` + + pub fn tanh(a: &Tensor, scale_input: f64) -> Tensor { + a.par_enum_map(|_, a_i| { + let kix = (a_i as f64) / scale_input; + let fout = scale_input * kix.tanh(); + let rounded = fout.round(); + Ok::<_, TensorError>(rounded as i128) + }) + .unwrap() + } + + /// Elementwise applies arctanh activation to a tensor of integers. + /// # Arguments + /// + /// * `a` - Tensor + /// * `scale_input` - Single value + /// * `scale_output` - Single value + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use ezkl::tensor::ops::nonlinearities::atanh; + /// let x = Tensor::::new( + /// Some(&[4, 25, 8, 2, 2, 0]), + /// &[2, 3], + /// ).unwrap(); + /// let result = atanh(&x, 32.0); + /// let expected = Tensor::::new(Some(&[4, 34, 8, 2, 2, 0]), &[2, 3]).unwrap(); + /// assert_eq!(result, expected); + /// ``` + + pub fn atanh(a: &Tensor, scale_input: f64) -> Tensor { + a.par_enum_map(|_, a_i| { + let kix = (a_i as f64) / scale_input; + let fout = scale_input * kix.atanh(); + let rounded = fout.round(); + Ok::<_, TensorError>(rounded as i128) + }) + .unwrap() + } + + /// Applies error function (erf) on a tensor of integers. 
+ /// # Arguments + /// + /// * `a` - Tensor + /// * `scale_input` - Single value + /// * `scale_output` - Single value + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use ezkl::tensor::ops::nonlinearities::erffunc; + /// let x = Tensor::::new( + /// Some(&[5, 28, 9, 1, 1, 0]), + /// &[2, 3], + /// ).unwrap(); + /// let result = erffunc(&x, 128.0); + /// let expected = Tensor::::new(Some(&[6, 31, 10, 1, 1, 0]), &[2, 3]).unwrap(); + /// assert_eq!(result, expected); + /// ``` + pub fn erffunc(a: &Tensor, scale_input: f64) -> Tensor { + const NCOEF: usize = 28; + const COF: [f64; 28] = [ + -1.3026537197817094, + 6.419_697_923_564_902e-1, + 1.9476473204185836e-2, + -9.561_514_786_808_63e-3, + -9.46595344482036e-4, + 3.66839497852761e-4, + 4.2523324806907e-5, + -2.0278578112534e-5, + -1.624290004647e-6, + 1.303655835580e-6, + 1.5626441722e-8, + -8.5238095915e-8, + 6.529054439e-9, + 5.059343495e-9, + -9.91364156e-10, + -2.27365122e-10, + 9.6467911e-11, + 2.394038e-12, + -6.886027e-12, + 8.94487e-13, + 3.13092e-13, + -1.12708e-13, + 3.81e-16, + 7.106e-15, + -1.523e-15, + -9.4e-17, + 1.21e-16, + -2.8e-17, + ]; + + /// Chebyshev coefficients + fn erfccheb(z: f64) -> f64 { + let mut d = 0f64; + let mut dd = 0f64; + + assert!(z >= 0f64, "erfccheb requires nonnegative argument"); + let t = 2f64 / (2f64 + z); + let ty = 4f64 * t - 2f64; + for j in (1..NCOEF - 1).rev() { + let tmp = d; + d = ty * d - dd + COF[j]; + dd = tmp; + } + t * (-z.powi(2) + 0.5 * (COF[0] + ty * d) - dd).exp() + } + + pub fn erf(x: f64) -> f64 { + if x >= 0f64 { + 1.0 - erfccheb(x) + } else { + erfccheb(-x) - 1f64 + } + } + + a.par_enum_map(|_, a_i| { + let kix = (a_i as f64) / scale_input; + let fout = scale_input * erf(kix); + let rounded = fout.round(); + Ok::<_, TensorError>(rounded as i128) + }) + .unwrap() + } + + /// Elementwise applies leaky relu to a tensor of integers. 
+ /// # Arguments + /// + /// * `a` - Tensor + /// * `scale` - Single value + /// * `slope` - Single value + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use ezkl::tensor::ops::nonlinearities::leakyrelu; + /// let x = Tensor::::new( + /// Some(&[2, 15, 2, 1, 1, -5]), + /// &[2, 3], + /// ).unwrap(); + /// let result = leakyrelu(&x, 0.1); + /// let expected = Tensor::::new(Some(&[2, 15, 2, 1, 1, -1]), &[2, 3]).unwrap(); + /// assert_eq!(result, expected); + /// ``` + pub fn leakyrelu(a: &Tensor, slope: f64) -> Tensor { + a.par_enum_map(|_, a_i| { + let rounded = if a_i < 0 { + let d_inv_x = (slope) * (a_i as f64); + d_inv_x.round() as i128 + } else { + let d_inv_x = a_i as f64; + d_inv_x.round() as i128 + }; + Ok::<_, TensorError>(rounded) + }) + .unwrap() + } + + /// Elementwise applies max to a tensor of integers. + /// # Arguments + /// * `a` - Tensor + /// * `b` - scalar + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use ezkl::tensor::ops::nonlinearities::max; + /// let x = Tensor::::new( + /// Some(&[2, 15, 2, 1, 1, -5]), + /// &[2, 3], + /// ).unwrap(); + /// let result = max(&x, 1, 1, 1.0); + /// let expected = Tensor::::new(Some(&[2, 15, 2, 1, 1, 1]), &[2, 3]).unwrap(); + /// assert_eq!(result, expected); + /// ``` + pub fn max( + a: &Tensor, + in_scale: usize, + out_scale: usize, + threshold: f64, + ) -> Tensor { + // calculate value of output + a.par_enum_map(|_, a_i| { + let d_inv_x = (a_i as f64) / (in_scale as f64); + let rounded = if d_inv_x <= threshold { + (threshold * (out_scale as f64)).round() as i128 + } else { + (d_inv_x * (out_scale as f64)).round() as i128 + }; + Ok::<_, TensorError>(rounded) + }) + .unwrap() + } + + /// Elementwise applies min to a tensor of integers. 
+ /// # Arguments + /// * `a` - Tensor + /// * `b` - scalar + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use ezkl::tensor::ops::nonlinearities::min; + /// let x = Tensor::::new( + /// Some(&[2, 15, 2, 1, 1, -5]), + /// &[2, 3], + /// ).unwrap(); + /// let result = min(&x, 1, 1, 2.0); + /// let expected = Tensor::::new(Some(&[2, 2, 2, 1, 1, -5]), &[2, 3]).unwrap(); + /// assert_eq!(result, expected); + /// ``` + pub fn min( + a: &Tensor, + in_scale: usize, + out_scale: usize, + threshold: f64, + ) -> Tensor { + // calculate value of output + a.par_enum_map(|_, a_i| { + let d_inv_x = (a_i as f64) / (in_scale as f64); + let rounded = if d_inv_x >= threshold { + (threshold * (out_scale as f64)).round() as i128 + } else { + (d_inv_x * (out_scale as f64)).round() as i128 + }; + Ok::<_, TensorError>(rounded) + }) + .unwrap() + } + + /// Elementwise divides a tensor with a const integer element. + /// # Arguments + /// + /// * `a` - Tensor + /// * `b` - Single value + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use ezkl::tensor::ops::nonlinearities::const_div; + /// let x = Tensor::::new( + /// Some(&[2, 1, 2, 7, 1, 1]), + /// &[2, 3], + /// ).unwrap(); + /// let k = 2.0; + /// let result = const_div(&x, k); + /// let expected = Tensor::::new(Some(&[1, 1, 1, 4, 1, 1]), &[2, 3]).unwrap(); + /// assert_eq!(result, expected); + /// ``` + pub fn const_div(a: &Tensor, denom: f64) -> Tensor { + a.par_enum_map(|_, a_i| { + let d_inv_x = (a_i as f64) / (denom); + Ok::<_, TensorError>(d_inv_x.round() as i128) + }) + .unwrap() + } + + /// Elementwise inverse. 
+ /// # Arguments + /// + /// * `a` - Tensor + /// * `b` - Single value + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use ezkl::tensor::ops::nonlinearities::recip; + /// let x = Tensor::::new( + /// Some(&[2, 1, 2, 7, 1, 1]), + /// &[2, 3], + /// ).unwrap(); + /// let k = 2_f64; + /// let result = recip(&x, k); + /// let expected = Tensor::::new(Some(&[1, 2, 1, 0, 2, 2]), &[2, 3]).unwrap(); + /// assert_eq!(result, expected); + /// ``` + pub fn recip(a: &Tensor, scale: f64) -> Tensor { + a.par_enum_map(|_, a_i| { + let denom = (1_f64) / (a_i as f64 + f64::EPSILON); + let d_inv_x = scale * denom; + Ok::<_, TensorError>(d_inv_x.round() as i128) + }) + .unwrap() + } + + /// Elementwise greater than + /// # Arguments + /// + /// * `a` - Tensor + /// * `b` - Single value + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use ezkl::tensor::ops::nonlinearities::greater_than; + /// let x = Tensor::::new( + /// Some(&[2, 1, 2, 7, 1, 1]), + /// &[2, 3], + /// ).unwrap(); + /// let k = 2.0; + /// let result = greater_than(&x, k); + /// let expected = Tensor::::new(Some(&[0, 0, 0, 1, 0, 0]), &[2, 3]).unwrap(); + /// assert_eq!(result, expected); + /// ``` + pub fn greater_than(a: &Tensor, b: f64) -> Tensor { + a.par_enum_map(|_, a_i| Ok::<_, TensorError>(i128::from((a_i as f64 - b) > 0_f64))) + .unwrap() + } + + /// Elementwise greater than + /// # Arguments + /// + /// * `a` - Tensor + /// * `b` - Single value + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use ezkl::tensor::ops::nonlinearities::greater_than_equal; + /// let x = Tensor::::new( + /// Some(&[2, 1, 2, 7, 1, 1]), + /// &[2, 3], + /// ).unwrap(); + /// let k = 2.0; + /// let result = greater_than_equal(&x, k); + /// let expected = Tensor::::new(Some(&[1, 0, 1, 1, 0, 0]), &[2, 3]).unwrap(); + /// assert_eq!(result, expected); + /// ``` + pub fn greater_than_equal(a: &Tensor, b: f64) -> Tensor { + a.par_enum_map(|_, a_i| Ok::<_, TensorError>(i128::from((a_i as f64 - 
b) >= 0_f64))) + .unwrap() + } + + /// Elementwise less than + /// # Arguments + /// * `a` - Tensor + /// * `b` - Single value + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use ezkl::tensor::ops::nonlinearities::less_than; + /// + /// let x = Tensor::::new( + /// Some(&[2, 1, 2, 7, 1, 1]), + /// &[2, 3], + /// ).unwrap(); + /// let k = 2.0; + /// + /// let result = less_than(&x, k); + /// let expected = Tensor::::new(Some(&[0, 1, 0, 0, 1, 1]), &[2, 3]).unwrap(); + /// assert_eq!(result, expected); + /// ``` + pub fn less_than(a: &Tensor, b: f64) -> Tensor { + a.par_enum_map(|_, a_i| Ok::<_, TensorError>(i128::from((a_i as f64 - b) < 0_f64))) + .unwrap() + } + + /// Elementwise less than + /// # Arguments + /// * `a` - Tensor + /// * `b` - Single value + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use ezkl::tensor::ops::nonlinearities::less_than_equal; + /// + /// let x = Tensor::::new( + /// Some(&[2, 1, 2, 7, 1, 1]), + /// &[2, 3], + /// ).unwrap(); + /// let k = 2.0; + /// + /// let result = less_than_equal(&x, k); + /// let expected = Tensor::::new(Some(&[1, 1, 1, 0, 1, 1]), &[2, 3]).unwrap(); + /// assert_eq!(result, expected); + /// ``` + pub fn less_than_equal(a: &Tensor, b: f64) -> Tensor { + a.par_enum_map(|_, a_i| Ok::<_, TensorError>(i128::from((a_i as f64 - b) <= 0_f64))) + .unwrap() + } + + /// Takes the mean of a tensor + /// # Arguments + /// + /// * `a` - Tensor + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use ezkl::tensor::ops::nonlinearities::mean; + /// let x = Tensor::::new( + /// Some(&[2, 1, 2, 7, 1, 1]), + /// &[2, 3], + /// ).unwrap(); + /// let result = mean(&x, 1); + /// let expected = Tensor::::new(Some(&[2]), &[1]).unwrap(); + /// assert_eq!(result, expected); + /// ``` + pub fn mean(a: &Tensor, scale: usize) -> Tensor { + let sum = sum(a).unwrap(); + const_div(&sum, (scale * a.len()) as f64) + } +} + +/// Ops that return the transcript i.e intermediate calcs of an op +pub mod 
accumulated { + use super::*; + + /// Dot product of two tensors. + /// # Arguments + /// + /// * `inputs` - Vector of tensors of length 2. + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use ezkl::tensor::ops::accumulated::dot; + /// + /// let x = Tensor::::new( + /// Some(&[5, 2]), + /// &[2], + /// ).unwrap(); + /// let y = Tensor::::new( + /// Some(&[5, 5]), + /// &[2], + /// ).unwrap(); + /// let expected = Tensor::::new( + /// Some(&[25, 35]), + /// &[2], + /// ).unwrap(); + /// assert_eq!(dot(&[x, y], 1).unwrap(), expected); + /// ``` + pub fn dot + Add>( + inputs: &[Tensor; 2], + chunk_size: usize, + ) -> Result, TensorError> { + if inputs[0].clone().len() != inputs[1].clone().len() { + return Err(TensorError::DimMismatch("dot".to_string())); + } + let (a, b): (Tensor, Tensor) = (inputs[0].clone(), inputs[1].clone()); + + let transcript: Tensor = a + .iter() + .zip(b) + .chunks(chunk_size) + .into_iter() + .scan(T::zero().unwrap(), |acc, chunk| { + let k = chunk.fold(T::zero().unwrap(), |acc, (a_i, b_i)| { + acc.clone() + a_i.clone() * b_i.clone() + }); + *acc = acc.clone() + k.clone(); + Some(acc.clone()) + }) + .collect(); + + Ok(transcript) + } + + /// Sums a tensor. 
+ /// # Arguments + /// + /// * `a` - Tensor + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use ezkl::tensor::ops::accumulated::sum; + /// let x = Tensor::::new( + /// Some(&[2, 15, 2, 1, 1, 0]), + /// &[2, 3], + /// ).unwrap(); + /// let result = sum(&x, 1).unwrap(); + /// let expected = Tensor::::new( + /// Some(&[2, 17, 19, 20, 21, 21]), + /// &[6], + /// ).unwrap(); + /// assert_eq!(result, expected); + /// ``` + pub fn sum + Add>( + a: &Tensor, + chunk_size: usize, + ) -> Result, TensorError> { + let transcript: Tensor = a + .iter() + .chunks(chunk_size) + .into_iter() + .scan(T::zero().unwrap(), |acc, chunk| { + let k = chunk.fold(T::zero().unwrap(), |acc, a_i| acc.clone() + a_i.clone()); + *acc = acc.clone() + k.clone(); + Some(acc.clone()) + }) + .collect(); + + Ok(transcript) + } + + /// Prod of a tensor. + /// # Arguments + /// + /// * `a` - Tensor + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use ezkl::tensor::ops::accumulated::prod; + /// let x = Tensor::::new( + /// Some(&[2, 15, 2, 1, 1, 0]), + /// &[2, 3], + /// ).unwrap(); + /// let result = prod(&x, 1).unwrap(); + /// let expected = Tensor::::new( + /// Some(&[2, 30, 60, 60, 60, 0]), + /// &[6], + /// ).unwrap(); + /// assert_eq!(result, expected); + /// ``` + pub fn prod + Add>( + a: &Tensor, + chunk_size: usize, + ) -> Result, TensorError> { + let transcript: Tensor = a + .iter() + .chunks(chunk_size) + .into_iter() + .scan(T::one().unwrap(), |acc, chunk| { + let k = chunk.fold(T::one().unwrap(), |acc, a_i| acc.clone() * a_i.clone()); + *acc = acc.clone() * k.clone(); + Some(acc.clone()) + }) + .collect(); + + Ok(transcript) + } +} diff --git a/mnist_ezkl/src/tensor/val.rs b/mnist_ezkl/src/tensor/val.rs new file mode 100644 index 0000000..e375d61 --- /dev/null +++ b/mnist_ezkl/src/tensor/val.rs @@ -0,0 +1,870 @@ +use super::{ + ops::{intercalate_values, pad, resize}, + *, +}; +use halo2_proofs::{arithmetic::Field, plonk::Instance}; + +#[derive(Debug, Clone)] 
+/// A [ValType] is a wrapper around Halo2 value(s). +pub enum ValType { + /// value + Value(Value), + /// assigned value + AssignedValue(Value>), + /// previously assigned value + PrevAssigned(AssignedCell), + /// constant + Constant(F), + /// assigned constant + AssignedConstant(AssignedCell, F), +} + +impl ValType { + /// Returns true if the value is previously assigned. + pub fn is_prev_assigned(&self) -> bool { + matches!( + self, + ValType::PrevAssigned(_) | ValType::AssignedConstant(..) + ) + } + /// Returns true if the value is constant. + pub fn is_constant(&self) -> bool { + matches!(self, ValType::Constant(_) | ValType::AssignedConstant(..)) + } + + /// get felt eval + pub fn get_felt_eval(&self) -> Option { + let mut res = None; + match self { + ValType::Value(v) => { + v.map(|f| { + res = Some(f); + }); + } + ValType::AssignedValue(v) => { + v.map(|f| { + res = Some(f.evaluate()); + }); + } + ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => { + v.value_field().map(|f| { + res = Some(f.evaluate()); + }); + } + ValType::Constant(v) => { + res = Some(*v); + } + } + res + } + + /// get_prev_assigned + pub fn get_prev_assigned(&self) -> Option> { + match self { + ValType::PrevAssigned(v) => Some(v.clone()), + ValType::AssignedConstant(v, _) => Some(v.clone()), + _ => None, + } + } +} + +impl From> for i32 { + fn from(val: ValType) -> Self { + match val { + ValType::Value(v) => { + let mut output = 0_i32; + let mut i = 0; + v.map(|y| { + let e = felt_to_i32(y); + output = e; + i += 1; + }); + output + } + ValType::AssignedValue(v) => { + let mut output = 0_i32; + let mut i = 0; + v.evaluate().map(|y| { + let e = felt_to_i32(y); + output = e; + i += 1; + }); + output + } + ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) 
=> { + let mut output = 0_i32; + let mut i = 0; + v.value().map(|y| { + let e = felt_to_i32(*y); + output = e; + i += 1; + }); + output + } + ValType::Constant(v) => felt_to_i32(v), + } + } +} + +impl From for ValType { + fn from(t: F) -> ValType { + ValType::Constant(t) + } +} + +impl From> for ValType { + fn from(t: Value) -> ValType { + ValType::Value(t) + } +} + +impl From>> for ValType { + fn from(t: Value>) -> ValType { + ValType::AssignedValue(t) + } +} + +impl From> for ValType { + fn from(t: AssignedCell) -> ValType { + ValType::PrevAssigned(t) + } +} + +impl TensorType for ValType +where + F: Field, +{ + fn zero() -> Option { + Some(ValType::Constant(::ZERO)) + } + fn one() -> Option { + Some(ValType::Constant(::ONE)) + } +} + +/// A [ValTensor] is a wrapper around a [Tensor] of [ValType]. +/// or a column of an [Instance]. +/// This is the type used for all intermediate values in a circuit. +/// It is also the type used for the inputs and outputs of a circuit. +#[derive(Debug, Clone)] +pub enum ValTensor { + /// A tensor of [Value], each containing a field element + Value { + /// Underlying [Tensor]. + inner: Tensor>, + /// Vector of dimensions of the tensor. + dims: Vec, + /// + scale: crate::Scale, + }, + /// A tensor backed by an [Instance] column + Instance { + /// [Instance] + inner: Column, + /// Vector of dimensions of the tensor. 
+ dims: Vec>, + /// Current instance num + idx: usize, + /// + initial_offset: usize, + /// + scale: crate::Scale, + }, +} + +impl TensorType for ValTensor { + fn zero() -> Option { + Some(ValTensor::Value { + inner: Tensor::zero()?, + dims: vec![], + scale: 0, + }) + } +} + +impl From>> for ValTensor { + fn from(t: Tensor>) -> ValTensor { + ValTensor::Value { + inner: t.map(|x| x), + dims: t.dims().to_vec(), + scale: 1, + } + } +} + +impl From>> for ValTensor { + fn from(t: Vec>) -> ValTensor { + ValTensor::Value { + inner: t.clone().into_iter().into(), + dims: vec![t.len()], + scale: 1, + } + } +} + +impl TryFrom> for ValTensor { + type Error = Box; + fn try_from(t: Tensor) -> Result, Box> { + let visibility = t.visibility.clone(); + let dims = t.dims().to_vec(); + let inner = t.into_iter().map(|x| { + if let Some(vis) = &visibility { + match vis { + Visibility::Fixed => Ok(ValType::Constant(x)), + _ => { + Ok(Value::known(x).into()) + } + } + } + else { + Err("visibility should be set to convert a tensor of field elements to a ValTensor.".into()) + } + }).collect::, Box>>()?; + + let mut inner: Tensor> = inner.into_iter().into(); + inner.reshape(&dims)?; + + Ok(ValTensor::Value { + inner, + dims, + scale: 1, + }) + } +} + +impl From>> for ValTensor { + fn from(t: Tensor>) -> ValTensor { + ValTensor::Value { + inner: t.map(|x| x.into()), + dims: t.dims().to_vec(), + scale: 1, + } + } +} + +impl From>>> for ValTensor { + fn from(t: Tensor>>) -> ValTensor { + ValTensor::Value { + inner: t.map(|x| x.into()), + dims: t.dims().to_vec(), + scale: 1, + } + } +} + +impl From>> for ValTensor { + fn from(t: Tensor>) -> ValTensor { + ValTensor::Value { + inner: t.map(|x| x.into()), + dims: t.dims().to_vec(), + scale: 1, + } + } +} + +impl ValTensor { + /// Allocate a new [ValTensor::Instance] from the ConstraintSystem with the given tensor `dims`, optionally enabling `equality`. 
+ pub fn new_instance( + cs: &mut ConstraintSystem, + dims: Vec>, + scale: crate::Scale, + ) -> Self { + let col = cs.instance_column(); + cs.enable_equality(col); + + ValTensor::Instance { + inner: col, + dims, + initial_offset: 0, + idx: 0, + scale, + } + } + + /// Allocate a new [ValTensor::Instance] from the ConstraintSystem with the given tensor `dims`, optionally enabling `equality`. + pub fn new_instance_from_col( + dims: Vec>, + scale: crate::Scale, + col: Column, + ) -> Self { + ValTensor::Instance { + inner: col, + dims, + idx: 0, + initial_offset: 0, + scale, + } + } + + /// + pub fn get_total_instance_len(&self) -> usize { + match self { + ValTensor::Instance { dims, .. } => dims + .iter() + .map(|x| { + if !x.is_empty() { + x.iter().product::() + } else { + 0 + } + }) + .sum(), + _ => 0, + } + } + + /// + pub fn is_instance(&self) -> bool { + matches!(self, ValTensor::Instance { .. }) + } + + /// + pub fn set_initial_instance_offset(&mut self, offset: usize) { + if let ValTensor::Instance { initial_offset, .. } = self { + *initial_offset = offset; + } + } + + /// + pub fn increment_idx(&mut self) { + if let ValTensor::Instance { idx, .. } = self { + *idx += 1; + } + } + + /// + pub fn set_idx(&mut self, val: usize) { + if let ValTensor::Instance { idx, .. } = self { + *idx = val; + } + } + + /// + pub fn get_idx(&self) -> usize { + match self { + ValTensor::Instance { idx, .. } => *idx, + _ => 0, + } + } + + /// + pub fn any_unknowns(&self) -> Result> { + match self { + ValTensor::Instance { .. } => Ok(true), + _ => Ok(self.get_inner()?.iter().any(|&x| { + let mut is_empty = true; + x.map(|_| is_empty = false); + is_empty + })), + } + } + + /// Returns true if all the [ValTensor]'s [Value]s are assigned. + pub fn all_prev_assigned(&self) -> bool { + match self { + ValTensor::Value { inner, .. } => inner.iter().all(|x| x.is_prev_assigned()), + ValTensor::Instance { .. } => false, + } + } + + /// Set the [ValTensor]'s scale. 
+ pub fn set_scale(&mut self, scale: crate::Scale) { + match self { + ValTensor::Value { scale: s, .. } => *s = scale, + ValTensor::Instance { scale: s, .. } => *s = scale, + } + } + + /// Returns the [ValTensor]'s scale. + pub fn scale(&self) -> crate::Scale { + match self { + ValTensor::Value { scale, .. } => *scale, + ValTensor::Instance { scale, .. } => *scale, + } + } + + /// Returns the number of constants in the [ValTensor]. + pub fn num_constants(&self) -> usize { + match self { + ValTensor::Value { inner, .. } => inner.iter().filter(|x| x.is_constant()).count(), + ValTensor::Instance { .. } => 0, + } + } + + /// Fetch the underlying [Tensor] of field elements. + pub fn get_felt_evals(&self) -> Result, Box> { + let mut felt_evals: Vec = vec![]; + match self { + ValTensor::Value { + inner: v, dims: _, .. + } => { + // we have to push to an externally created vector or else vaf.map() returns an evaluation wrapped in Value<> (which we don't want) + let _ = v.map(|vaf| { + if let Some(f) = vaf.get_felt_eval() { + felt_evals.push(f); + } + }); + } + _ => return Err(Box::new(TensorError::WrongMethod)), + }; + + let mut res: Tensor = felt_evals.into_iter().into(); + res.reshape(self.dims())?; + Ok(res) + } + + /// Calls is_singleton on the inner tensor. + pub fn is_singleton(&self) -> bool { + match self { + ValTensor::Value { inner, .. } => inner.is_singleton(), + ValTensor::Instance { .. } => false, + } + } + + /// Calls `int_evals` on the inner tensor. + pub fn get_int_evals(&self) -> Result, Box> { + // finally convert to vector of integers + let mut integer_evals: Vec = vec![]; + match self { + ValTensor::Value { + inner: v, dims: _, .. 
+ } => { + // we have to push to an externally created vector or else vaf.map() returns an evaluation wrapped in Value<> (which we don't want) + let _ = v.map(|vaf| match vaf { + ValType::Value(v) => v.map(|f| { + integer_evals.push(crate::fieldutils::felt_to_i128(f)); + }), + ValType::AssignedValue(v) => v.map(|f| { + integer_evals.push(crate::fieldutils::felt_to_i128(f.evaluate())); + }), + ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => { + v.value_field().map(|f| { + integer_evals.push(crate::fieldutils::felt_to_i128(f.evaluate())); + }) + } + ValType::Constant(v) => { + integer_evals.push(crate::fieldutils::felt_to_i128(v)); + Value::unknown() + } + }); + } + _ => return Err(Box::new(TensorError::WrongMethod)), + }; + Ok(integer_evals.into_iter().into()) + } + + /// Calls `pad_to_zero_rem` on the inner tensor. + pub fn pad_to_zero_rem(&mut self, n: usize) -> Result<(), Box> { + match self { + ValTensor::Value { + inner: v, dims: d, .. + } => { + *v = v.pad_to_zero_rem(n)?; + *d = v.dims().to_vec(); + } + ValTensor::Instance { .. } => { + return Err(Box::new(TensorError::WrongMethod)); + } + }; + Ok(()) + } + + /// Calls `get_slice` on the inner tensor. + pub fn get_slice(&self, indices: &[Range]) -> Result, Box> { + if indices.iter().map(|x| x.end - x.start).collect::>() == self.dims() { + return Ok(self.clone()); + } + let slice = match self { + ValTensor::Value { + inner: v, + dims: _, + scale, + } => { + let inner = v.get_slice(indices)?; + let dims = inner.dims().to_vec(); + ValTensor::Value { + inner, + dims, + scale: *scale, + } + } + _ => return Err(Box::new(TensorError::WrongMethod)), + }; + Ok(slice) + } + + /// Calls `get_single_elem` on the inner tensor. 
+ pub fn get_single_elem(&self, index: usize) -> Result, Box> { + let slice = match self { + ValTensor::Value { + inner: v, + dims: _, + scale, + } => { + let inner = Tensor::from(vec![v.get_flat_index(index)].into_iter()); + ValTensor::Value { + inner, + dims: vec![1], + scale: *scale, + } + } + _ => return Err(Box::new(TensorError::WrongMethod)), + }; + Ok(slice) + } + + /// Fetches the inner tensor as a `Tensor` + pub fn get_inner_tensor(&self) -> Result<&Tensor>, TensorError> { + Ok(match self { + ValTensor::Value { inner: v, .. } => v, + ValTensor::Instance { .. } => return Err(TensorError::WrongMethod), + }) + } + + /// Fetches the inner tensor as a `Tensor` + pub fn get_inner_tensor_mut(&mut self) -> Result<&mut Tensor>, TensorError> { + Ok(match self { + ValTensor::Value { inner: v, .. } => v, + ValTensor::Instance { .. } => return Err(TensorError::WrongMethod), + }) + } + + /// Fetches the inner tensor as a `Tensor>` + pub fn get_inner(&self) -> Result>, TensorError> { + Ok(match self { + ValTensor::Value { inner: v, .. } => v.map(|x| match x { + ValType::Value(v) => v, + ValType::AssignedValue(v) => v.evaluate(), + ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => { + v.value_field().evaluate() + } + ValType::Constant(v) => Value::known(v), + }), + ValTensor::Instance { .. } => return Err(TensorError::WrongMethod), + }) + } + /// Calls `expand` on the inner tensor. + pub fn expand(&mut self, dims: &[usize]) -> Result<(), Box> { + match self { + ValTensor::Value { + inner: v, dims: d, .. + } => { + *v = v.expand(dims)?; + *d = v.dims().to_vec(); + } + ValTensor::Instance { .. } => { + return Err(Box::new(TensorError::WrongMethod)); + } + }; + Ok(()) + } + + /// Calls `move_axis` on the inner tensor. + pub fn move_axis(&mut self, source: usize, destination: usize) -> Result<(), Box> { + match self { + ValTensor::Value { + inner: v, dims: d, .. 
+ } => { + *v = v.move_axis(source, destination)?; + *d = v.dims().to_vec(); + } + ValTensor::Instance { .. } => { + return Err(Box::new(TensorError::WrongMethod)); + } + }; + Ok(()) + } + + /// Sets the [ValTensor]'s shape. + pub fn reshape(&mut self, new_dims: &[usize]) -> Result<(), Box> { + match self { + ValTensor::Value { + inner: v, dims: d, .. + } => { + v.reshape(new_dims)?; + *d = v.dims().to_vec(); + } + ValTensor::Instance { dims: d, idx, .. } => { + if d[*idx].iter().product::() != new_dims.iter().product::() { + return Err(Box::new(TensorError::DimError)); + } + d[*idx] = new_dims.to_vec(); + } + }; + Ok(()) + } + + /// Sets the [ValTensor]'s shape. + pub fn slice( + &mut self, + axis: &usize, + start: &usize, + end: &usize, + ) -> Result<(), Box> { + match self { + ValTensor::Value { + inner: v, dims: d, .. + } => { + *v = crate::tensor::ops::slice(v, axis, start, end)?; + *d = v.dims().to_vec(); + } + ValTensor::Instance { .. } => { + return Err(Box::new(TensorError::WrongMethod)); + } + }; + Ok(()) + } + + /// Calls `flatten` on the inner [Tensor]. + pub fn flatten(&mut self) { + match self { + ValTensor::Value { + inner: v, dims: d, .. + } => { + v.flatten(); + *d = v.dims().to_vec(); + } + ValTensor::Instance { dims: d, idx, .. } => { + d[*idx] = vec![d[*idx].iter().product()]; + } + } + } + + /// Calls `duplicate_every_n` on the inner [Tensor]. + pub fn duplicate_every_n( + &mut self, + n: usize, + num_repeats: usize, + initial_offset: usize, + ) -> Result<(), TensorError> { + match self { + ValTensor::Value { + inner: v, dims: d, .. + } => { + *v = v.duplicate_every_n(n, num_repeats, initial_offset)?; + *d = v.dims().to_vec(); + } + ValTensor::Instance { .. } => { + return Err(TensorError::WrongMethod); + } + } + Ok(()) + } + + /// gets constants + pub fn get_const_zero_indices(&self) -> Result, TensorError> { + match self { + ValTensor::Value { inner: v, .. 
} => { + let mut indices = vec![]; + for (i, e) in v.iter().enumerate() { + if let ValType::Constant(r) = e { + if *r == F::ZERO { + indices.push(i); + } + } else if let ValType::AssignedConstant(_, r) = e { + if *r == F::ZERO { + indices.push(i); + } + } + } + Ok(indices) + } + ValTensor::Instance { .. } => Err(TensorError::WrongMethod), + } + } + + /// gets constants + pub fn get_const_indices(&self) -> Result, TensorError> { + match self { + ValTensor::Value { inner: v, .. } => { + let mut indices = vec![]; + for (i, e) in v.iter().enumerate() { + if let ValType::Constant(_) = e { + indices.push(i); + } else if let ValType::AssignedConstant(_, _) = e { + indices.push(i); + } + } + Ok(indices) + } + ValTensor::Instance { .. } => Err(TensorError::WrongMethod), + } + } + + /// calls `remove_indices` on the inner [Tensor]. + pub fn remove_indices( + &mut self, + indices: &mut [usize], + is_sorted: bool, + ) -> Result<(), TensorError> { + match self { + ValTensor::Value { + inner: v, dims: d, .. + } => { + // this is very slow how can we speed this up ? + *v = v.remove_indices(indices, is_sorted)?; + *d = v.dims().to_vec(); + } + ValTensor::Instance { .. } => { + return Err(TensorError::WrongMethod); + } + } + Ok(()) + } + + /// Calls `duplicate_every_n` on the inner [Tensor]. + pub fn remove_every_n( + &mut self, + n: usize, + num_repeats: usize, + initial_offset: usize, + ) -> Result<(), TensorError> { + match self { + ValTensor::Value { + inner: v, dims: d, .. + } => { + *v = v.remove_every_n(n, num_repeats, initial_offset)?; + *d = v.dims().to_vec(); + } + ValTensor::Instance { .. } => { + return Err(TensorError::WrongMethod); + } + } + Ok(()) + } + + /// Calls `intercalate_values` on the inner [Tensor]. + pub fn intercalate_values( + &mut self, + value: ValType, + stride: usize, + axis: usize, + ) -> Result<(), TensorError> { + match self { + ValTensor::Value { + inner: v, dims: d, .. 
+ } => { + *v = intercalate_values(v, value, stride, axis)?; + *d = v.dims().to_vec(); + } + ValTensor::Instance { .. } => { + return Err(TensorError::WrongMethod); + } + } + Ok(()) + } + /// Calls `resize` on the inner [Tensor]. + pub fn resize(&mut self, scales: &[usize]) -> Result<(), TensorError> { + match self { + ValTensor::Value { + inner: v, dims: d, .. + } => { + *v = resize(v, scales)?; + *d = v.dims().to_vec(); + } + ValTensor::Instance { .. } => { + return Err(TensorError::WrongMethod); + } + }; + Ok(()) + } + /// Calls `pad` on the inner [Tensor]. + pub fn pad(&mut self, padding: [(usize, usize); 2]) -> Result<(), TensorError> { + match self { + ValTensor::Value { + inner: v, dims: d, .. + } => { + *v = pad(v, padding)?; + *d = v.dims().to_vec(); + } + ValTensor::Instance { .. } => { + return Err(TensorError::WrongMethod); + } + } + Ok(()) + } + + /// Calls `len` on the inner [Tensor]. + pub fn len(&self) -> usize { + match self { + ValTensor::Value { dims, .. } => { + if !dims.is_empty() && (dims != &[0]) { + dims.iter().product::() + } else { + 0 + } + } + ValTensor::Instance { dims, idx, .. } => { + let dims = dims[*idx].clone(); + if !dims.is_empty() && (dims != [0]) { + dims.iter().product::() + } else { + 0 + } + } + } + } + + /// + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Calls `concats` on the inner [Tensor]. + pub fn concat(&self, other: Self) -> Result { + let res = match (self, other) { + (ValTensor::Value { inner: v1, .. }, ValTensor::Value { inner: v2, .. }) => { + ValTensor::from(Tensor::new(Some(&[v1.clone(), v2]), &[2])?.combine()?) + } + _ => { + return Err(TensorError::WrongMethod); + } + }; + Ok(res) + } + + /// Calls `concats` on the inner [Tensor]. + pub fn concat_axis(&self, other: Self, axis: &usize) -> Result { + let res = match (self, other) { + (ValTensor::Value { inner: v1, .. }, ValTensor::Value { inner: v2, .. 
}) => { + let v = crate::tensor::ops::concat(&[v1, &v2], *axis)?; + ValTensor::from(v) + } + _ => { + return Err(TensorError::WrongMethod); + } + }; + Ok(res) + } + + /// Returns the `dims` attribute of the [ValTensor]. + pub fn dims(&self) -> &[usize] { + match self { + ValTensor::Value { dims: d, .. } => d, + ValTensor::Instance { dims: d, idx, .. } => &d[*idx], + } + } + /// A [String] representation of the [ValTensor] for display, for example in showing intermediate values in a computational graph. + pub fn show(&self) -> String { + match self.clone() { + ValTensor::Value { + inner: v, dims: _, .. + } => { + let r: Tensor = v.map(|x| x.into()); + if r.len() > 10 { + let start = r[..5].to_vec(); + let end = r[r.len() - 5..].to_vec(); + // print the two split by ... in the middle + format!( + "[{} ... {}]", + start.iter().map(|x| format!("{}", x)).join(", "), + end.iter().map(|x| format!("{}", x)).join(", ") + ) + } else { + format!("{:?}", r) + } + } + _ => "ValTensor not PrevAssigned".into(), + } + } +} diff --git a/mnist_ezkl/src/tensor/var.rs b/mnist_ezkl/src/tensor/var.rs new file mode 100644 index 0000000..66cd5e3 --- /dev/null +++ b/mnist_ezkl/src/tensor/var.rs @@ -0,0 +1,586 @@ +use std::collections::HashSet; + +use log::{error, warn}; + +use crate::circuit::CheckMode; + +use super::*; +/// A wrapper around Halo2's `Column` or `Column`. +/// Typically assign [ValTensor]s to [VarTensor]s when laying out a circuit. +#[derive(Clone, Default, Debug, PartialEq, Eq)] +pub enum VarTensor { + /// A VarTensor for holding Advice values, which are assigned at proving time. + Advice { + /// Vec of Advice columns, we have [[xx][xx][xx]...] 
where each inner vec is xx columns + inner: Vec>>, + /// + num_inner_cols: usize, + /// Number of rows available to be used in each column of the storage + col_size: usize, + }, + /// Dummy var + Dummy { + /// + num_inner_cols: usize, + /// Number of rows available to be used in each column of the storage + col_size: usize, + }, + /// Empty var + #[default] + Empty, +} + +impl VarTensor { + /// + pub fn is_advice(&self) -> bool { + match self { + VarTensor::Advice { .. } => true, + _ => false, + } + } + + /// + pub fn max_rows(cs: &ConstraintSystem, logrows: usize) -> usize { + let base = 2u32; + base.pow(logrows as u32) as usize - cs.blinding_factors() - 1 + } + + /// Create a new VarTensor::Advice that is unblinded + /// Arguments + /// * `cs` - The constraint system + /// * `logrows` - log2 number of rows in the matrix, including any system and blinding rows. + /// * `capacity` - The number of advice cells to allocate + pub fn new_unblinded_advice( + cs: &mut ConstraintSystem, + logrows: usize, + num_inner_cols: usize, + capacity: usize, + ) -> Self { + let max_rows = Self::max_rows(cs, logrows) * num_inner_cols; + + let mut modulo = (capacity / max_rows) + 1; + // we add a buffer for duplicated rows (we get at most 1 duplicated row per column) + modulo = ((capacity + modulo) / max_rows) + 1; + let mut advices = vec![]; + + if modulo > 1 { + warn!( + "using column duplication for {} unblinded advice blocks", + modulo - 1 + ); + } + + for _ in 0..modulo { + let mut inner = vec![]; + for _ in 0..num_inner_cols { + let col = cs.unblinded_advice_column(); + cs.enable_equality(col); + inner.push(col); + } + advices.push(inner); + } + + VarTensor::Advice { + inner: advices, + num_inner_cols, + col_size: max_rows, + } + } + + /// Create a new VarTensor::Advice + /// Arguments + /// * `cs` - The constraint system + /// * `logrows` - log2 number of rows in the matrix, including any system and blinding rows. 
+ /// * `capacity` - The number of advice cells to allocate + pub fn new_advice( + cs: &mut ConstraintSystem, + logrows: usize, + num_inner_cols: usize, + capacity: usize, + ) -> Self { + let max_rows = Self::max_rows(cs, logrows); + let max_assignments = Self::max_rows(cs, logrows) * num_inner_cols; + + let mut modulo = (capacity / max_assignments) + 1; + // we add a buffer for duplicated rows (we get at most 1 duplicated row per column) + modulo = ((capacity + modulo) / max_assignments) + 1; + let mut advices = vec![]; + + if modulo > 1 { + warn!("using column duplication for {} advice blocks", modulo - 1); + } + + for _ in 0..modulo { + let mut inner = vec![]; + for _ in 0..num_inner_cols { + let col = cs.advice_column(); + cs.enable_equality(col); + inner.push(col); + } + advices.push(inner); + } + + VarTensor::Advice { + inner: advices, + num_inner_cols, + col_size: max_rows, + } + } + + /// Initializes fixed columns to support the VarTensor::Advice + /// Arguments + /// * `cs` - The constraint system + /// * `logrows` - log2 number of rows in the matrix, including any system and blinding rows. 
+ /// * `capacity` - The number of advice cells to allocate + pub fn constant_cols( + cs: &mut ConstraintSystem, + logrows: usize, + num_constants: usize, + module_requires_fixed: bool, + ) -> usize { + if num_constants == 0 && !module_requires_fixed { + return 0; + } else if num_constants == 0 && module_requires_fixed { + let col = cs.fixed_column(); + cs.enable_constant(col); + return 1; + } + + let max_rows = Self::max_rows(cs, logrows); + + let mut modulo = num_constants / max_rows + 1; + // we add a buffer for duplicated rows (we get at most 1 duplicated row per column) + modulo = (num_constants + modulo) / max_rows + 1; + + if modulo > 1 { + warn!("using column duplication for {} fixed columns", modulo - 1); + } + + for _ in 0..modulo { + let col = cs.fixed_column(); + cs.enable_constant(col); + } + modulo + } + + /// Create a new VarTensor::Dummy + pub fn dummy(logrows: usize, num_inner_cols: usize) -> Self { + let base = 2u32; + let max_rows = base.pow(logrows as u32) as usize - 6; + VarTensor::Dummy { + col_size: max_rows, + num_inner_cols, + } + } + + /// Gets the dims of the object the VarTensor represents + pub fn num_blocks(&self) -> usize { + match self { + VarTensor::Advice { inner, .. } => inner.len(), + _ => 0, + } + } + + /// Num inner cols + pub fn num_inner_cols(&self) -> usize { + match self { + VarTensor::Advice { num_inner_cols, .. } | VarTensor::Dummy { num_inner_cols, .. } => { + *num_inner_cols + } + _ => 0, + } + } + + /// Total number of columns + pub fn num_cols(&self) -> usize { + match self { + VarTensor::Advice { inner, .. } => inner[0].len() * inner.len(), + _ => 0, + } + } + + /// Gets the size of each column + pub fn col_size(&self) -> usize { + match self { + VarTensor::Advice { col_size, .. } | VarTensor::Dummy { col_size, .. } => *col_size, + _ => 0, + } + } + + /// Gets the size of each column + pub fn block_size(&self) -> usize { + match self { + VarTensor::Advice { + num_inner_cols, + col_size, + .. 
+ } + | VarTensor::Dummy { + col_size, + num_inner_cols, + .. + } => *col_size * num_inner_cols, + _ => 0, + } + } + + /// Take a linear coordinate and output the (column, row) position in the storage block. + pub fn cartesian_coord(&self, linear_coord: usize) -> (usize, usize, usize) { + // x indexes over blocks of size num_inner_cols + let x = linear_coord / self.block_size(); + // y indexes over the cols inside a block + let y = linear_coord % self.num_inner_cols(); + // z indexes over the rows inside a col + let z = (linear_coord - x * self.block_size()) / self.num_inner_cols(); + (x, y, z) + } +} + +impl VarTensor { + /// Retrieve the value of a specific cell in the tensor. + pub fn query_rng( + &self, + meta: &mut VirtualCells<'_, F>, + x: usize, + y: usize, + z: i32, + rng: usize, + ) -> Result>, halo2_proofs::plonk::Error> { + match &self { + // when advice we have 1 col per row + VarTensor::Advice { inner: advices, .. } => { + let c = Tensor::from( + // this should fail if dims is empty, should be impossible + (0..rng).map(|i| meta.query_advice(advices[x][y], Rotation(z + i as i32))), + ); + Ok(c) + } + _ => { + error!("VarTensor was not initialized"); + Err(halo2_proofs::plonk::Error::Synthesis) + } + } + } + + /// Retrieve the value of a specific block at an offset in the tensor. + pub fn query_whole_block( + &self, + meta: &mut VirtualCells<'_, F>, + x: usize, + z: i32, + rng: usize, + ) -> Result>, halo2_proofs::plonk::Error> { + match &self { + // when advice we have 1 col per row + VarTensor::Advice { inner: advices, .. 
} => { + let c = Tensor::from({ + // this should fail if dims is empty, should be impossible + let cartesian = (0..rng).cartesian_product(0..self.num_inner_cols()); + cartesian.map(|(i, y)| meta.query_advice(advices[x][y], Rotation(z + i as i32))) + }); + Ok(c) + } + _ => { + error!("VarTensor was not initialized"); + Err(halo2_proofs::plonk::Error::Synthesis) + } + } + } + + /// Assigns a constant value to a specific cell in the tensor. + pub fn assign_constant( + &self, + region: &mut Region, + offset: usize, + constant: F, + ) -> Result, halo2_proofs::plonk::Error> { + let (x, y, z) = self.cartesian_coord(offset); + match &self { + VarTensor::Advice { inner: advices, .. } => { + region.assign_advice_from_constant(|| "constant", advices[x][y], z, constant) + } + _ => { + error!("VarTensor was not initialized"); + Err(halo2_proofs::plonk::Error::Synthesis) + } + } + } + + /// Assigns [ValTensor] to the columns of the inner tensor. + pub fn assign_with_omissions( + &self, + region: &mut Region, + offset: usize, + values: &ValTensor, + omissions: &HashSet<&usize>, + ) -> Result, halo2_proofs::plonk::Error> { + let mut assigned_coord = 0; + let mut res: ValTensor = match values { + ValTensor::Instance { .. } => { + unimplemented!("cannot assign instance to advice columns with omissions") + } + ValTensor::Value { inner: v, .. } => Ok::<_, halo2_proofs::plonk::Error>( + v.enum_map(|coord, k| { + if omissions.contains(&coord) { + return Ok(k); + } + let cell = self.assign_value(region, offset, k.clone(), assigned_coord)?; + assigned_coord += 1; + + match k { + ValType::Constant(f) => Ok::, halo2_proofs::plonk::Error>( + ValType::AssignedConstant(cell, f), + ), + ValType::AssignedConstant(_, f) => Ok(ValType::AssignedConstant(cell, f)), + _ => Ok(ValType::PrevAssigned(cell)), + } + })? + .into(), + ), + }?; + res.set_scale(values.scale()); + Ok(res) + } + + /// Assigns [ValTensor] to the columns of the inner tensor. 
+ pub fn assign( + &self, + region: &mut Region, + offset: usize, + values: &ValTensor, + ) -> Result, halo2_proofs::plonk::Error> { + let mut res: ValTensor = match values { + ValTensor::Instance { + inner: instance, + dims, + idx, + initial_offset, + .. + } => match &self { + VarTensor::Advice { inner: v, .. } => { + let total_offset: usize = initial_offset + + dims[..*idx] + .iter() + .map(|x| x.iter().product::()) + .sum::(); + let dims = &dims[*idx]; + // this should never ever fail + let t: Tensor = Tensor::new(None, dims).unwrap(); + Ok(t.enum_map(|coord, _| { + let (x, y, z) = self.cartesian_coord(offset + coord); + region.assign_advice_from_instance( + || "pub input anchor", + *instance, + coord + total_offset, + v[x][y], + z, + ) + })? + .into()) + } + _ => { + error!("Instance is only supported for advice columns"); + Err(halo2_proofs::plonk::Error::Synthesis) + } + }, + ValTensor::Value { inner: v, .. } => Ok(v + .enum_map(|coord, k| { + let cell = self.assign_value(region, offset, k.clone(), coord)?; + match k { + ValType::Constant(f) => Ok::, halo2_proofs::plonk::Error>( + ValType::AssignedConstant(cell, f), + ), + ValType::AssignedConstant(_, f) => Ok(ValType::AssignedConstant(cell, f)), + _ => Ok(ValType::PrevAssigned(cell)), + } + })? + .into()), + }?; + res.set_scale(values.scale()); + Ok(res) + } + + /// Assigns specific values (`ValTensor`) to the columns of the inner tensor but allows for column wrapping for accumulated operations. + /// Duplication occurs by copying the last cell of the column to the first cell next column and creating a copy constraint between the two. + pub fn dummy_assign_with_duplication( + &self, + row: usize, + offset: usize, + values: &ValTensor, + single_inner_col: bool, + ) -> Result<(ValTensor, usize, usize), halo2_proofs::plonk::Error> { + match values { + ValTensor::Instance { .. } => unimplemented!("duplication is not supported on instance columns. 
increase K if you require more rows."), + ValTensor::Value { inner: v, dims , ..} => { + let duplication_freq = if single_inner_col { + self.col_size() + } else { + self.block_size() + }; + + let num_repeats = if single_inner_col { + 1 + } else { + self.num_inner_cols() + }; + + let duplication_offset = if single_inner_col { + row + } else { + offset + }; + + + // duplicates every nth element to adjust for column overflow + let mut res: ValTensor = v.duplicate_every_n(duplication_freq, num_repeats, duplication_offset).unwrap().into(); + let total_used_len = res.len(); + let total_constants = res.num_constants(); + res.remove_every_n(duplication_freq, num_repeats, duplication_offset).unwrap(); + + res.reshape(dims).unwrap(); + res.set_scale(values.scale()); + + Ok((res, total_used_len, total_constants)) + } + } + } + + /// Assigns specific values (`ValTensor`) to the columns of the inner tensor but allows for column wrapping for accumulated operations. + /// Duplication occurs by copying the last cell of the column to the first cell next column and creating a copy constraint between the two. + pub fn assign_with_duplication( + &self, + region: &mut Region, + row: usize, + offset: usize, + values: &ValTensor, + check_mode: &CheckMode, + single_inner_col: bool, + ) -> Result<(ValTensor, usize, usize), halo2_proofs::plonk::Error> { + let mut prev_cell = None; + + match values { + ValTensor::Instance { .. } => unimplemented!("duplication is not supported on instance columns. 
increase K if you require more rows."), + ValTensor::Value { inner: v, dims , ..} => { + + let duplication_freq = if single_inner_col { + self.col_size() + } else { + self.block_size() + }; + + let num_repeats = if single_inner_col { + 1 + } else { + self.num_inner_cols() + }; + + let duplication_offset = if single_inner_col { + row + } else { + offset + }; + + // duplicates every nth element to adjust for column overflow + let v = v.duplicate_every_n(duplication_freq, num_repeats, duplication_offset).unwrap(); + let mut res: ValTensor = { + v.enum_map(|coord, k| { + + let step = if !single_inner_col { + 1 + } else { + self.num_inner_cols() + }; + + let (x, y, z) = self.cartesian_coord(offset + coord * step); + if matches!(check_mode, CheckMode::SAFE) && coord > 0 && z == 0 && y == 0 { + // assert that duplication occurred correctly + assert_eq!(Into::::into(k.clone()), Into::::into(v[coord - 1].clone())); + }; + + let cell = self.assign_value(region, offset, k.clone(), coord * step)?; + + if single_inner_col { + if z == 0 { + // if we are at the end of the column, we need to copy the cell to the next column + prev_cell = Some(cell.clone()); + } else if coord > 0 && z == 0 && single_inner_col { + if let Some(prev_cell) = prev_cell.as_ref() { + region.constrain_equal(prev_cell.cell(),cell.cell())?; + } else { + error!("Error copy-constraining previous value: {:?}", (x,y)); + return Err(halo2_proofs::plonk::Error::Synthesis); + } + }} + + match k { + ValType::Constant(f) => { + Ok(ValType::AssignedConstant(cell, f)) + }, + ValType::AssignedConstant(_, f) => { + Ok(ValType::AssignedConstant(cell, f)) + }, + _ => { + Ok(ValType::PrevAssigned(cell)) + } + } + + })?.into()}; + let total_used_len = res.len(); + let total_constants = res.num_constants(); + res.remove_every_n(duplication_freq, num_repeats, duplication_offset).unwrap(); + + res.reshape(dims).unwrap(); + res.set_scale(values.scale()); + + if matches!(check_mode, CheckMode::SAFE) { + // during key generation 
this will be 0 so we use this as a flag to check + // TODO: this isn't very safe and would be better to get the phase directly + let is_assigned = !Into::>::into(res.clone().get_inner().unwrap()) + .iter() + .all(|&x| x == 0); + if is_assigned { + assert_eq!( + Into::>::into(values.get_inner().unwrap()), + Into::>::into(res.get_inner().unwrap()) + )}; + } + + Ok((res, total_used_len, total_constants)) + } + } + } + + fn assign_value( + &self, + region: &mut Region, + offset: usize, + k: ValType, + coord: usize, + ) -> Result, halo2_proofs::plonk::Error> { + let (x, y, z) = self.cartesian_coord(offset + coord); + match k { + ValType::Value(v) => match &self { + VarTensor::Advice { inner: advices, .. } => { + region.assign_advice(|| "k", advices[x][y], z, || v) + } + _ => unimplemented!(), + }, + ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => match &self { + VarTensor::Advice { inner: advices, .. } => { + v.copy_advice(|| "k", region, advices[x][y], z) + } + _ => { + error!("PrevAssigned is only supported for advice columns"); + Err(halo2_proofs::plonk::Error::Synthesis) + } + }, + ValType::AssignedValue(v) => match &self { + VarTensor::Advice { inner: advices, .. 
} => region + .assign_advice(|| "k", advices[x][y], z, || v) + .map(|a| a.evaluate()), + _ => unimplemented!(), + }, + ValType::Constant(v) => self.assign_constant(region, offset + coord, v), + } + } +} diff --git a/mnist_ezkl/src/wasm.rs b/mnist_ezkl/src/wasm.rs new file mode 100644 index 0000000..dbeaa19 --- /dev/null +++ b/mnist_ezkl/src/wasm.rs @@ -0,0 +1,598 @@ +use crate::circuit::modules::elgamal::ElGamalCipher; +use crate::circuit::modules::poseidon::spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH}; +use crate::circuit::modules::poseidon::PoseidonChip; +use crate::circuit::modules::Module; +use crate::fieldutils::felt_to_i128; +use crate::fieldutils::i128_to_felt; +use crate::graph::modules::POSEIDON_LEN_GRAPH; +use crate::graph::quantize_float; +use crate::graph::scale_to_multiplier; +use halo2_proofs::plonk::*; +use halo2_proofs::poly::commitment::{CommitmentScheme, ParamsProver}; +use halo2_proofs::poly::kzg::{ + commitment::{KZGCommitmentScheme, ParamsKZG}, + strategy::SingleStrategy as KZGSingleStrategy, +}; +use halo2_solidity_verifier::encode_calldata; +use halo2curves::bn256::{Bn256, Fr, G1Affine}; +use halo2curves::ff::{FromUniformBytes, PrimeField}; +use rand::rngs::StdRng; +use rand::SeedableRng; + +use crate::tensor::TensorType; +use wasm_bindgen::prelude::*; +use wasm_bindgen_console_logger::DEFAULT_LOGGER; + +use console_error_panic_hook; + +#[cfg(feature = "web")] +pub use wasm_bindgen_rayon::init_thread_pool; + +#[wasm_bindgen] +/// Initialize logger for wasm +pub fn init_logger() { + log::set_logger(&DEFAULT_LOGGER).unwrap(); +} + +#[wasm_bindgen] +/// Initialize panic hook for wasm +pub fn init_panic_hook() { + console_error_panic_hook::set_once(); +} + +use crate::graph::{GraphCircuit, GraphSettings}; +use crate::pfsys::{create_proof_circuit_kzg, verify_proof_circuit_kzg}; + +/// Wrapper around the halo2 encode call data method +#[wasm_bindgen] +#[allow(non_snake_case)] +pub fn encodeVerifierCalldata( + proof: wasm_bindgen::Clamped>, + 
vk_address: Option>, +) -> Result, JsError> { + let snark: crate::pfsys::Snark = serde_json::from_slice(&proof[..]) + .map_err(|e| JsError::new(&format!("Failed to deserialize proof: {}", e)))?; + + let vk_address: Option<[u8; 20]> = if let Some(vk_address) = vk_address { + let array: [u8; 20] = serde_json::from_slice(&vk_address[..]) + .map_err(|e| JsError::new(&format!("Failed to deserialize vk address: {}", e)))?; + Some(array) + } else { + None + }; + + let flattened_instances = snark.instances.into_iter().flatten(); + + let encoded = encode_calldata( + vk_address, + &snark.proof, + &flattened_instances.collect::>(), + ); + + Ok(encoded) +} + +/// Converts 4 u64s to a field element +#[wasm_bindgen] +#[allow(non_snake_case)] +pub fn vecU64ToFelt(array: wasm_bindgen::Clamped>) -> Result { + let felt: Fr = serde_json::from_slice(&array[..]) + .map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?; + Ok(format!("{:?}", felt)) +} + +/// Converts 4 u64s representing a field element directly to an integer +#[wasm_bindgen] +#[allow(non_snake_case)] +pub fn vecU64ToInt( + array: wasm_bindgen::Clamped>, +) -> Result>, JsError> { + let felt: Fr = serde_json::from_slice(&array[..]) + .map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?; + Ok(wasm_bindgen::Clamped( + serde_json::to_vec(&felt_to_i128(felt)) + .map_err(|e| JsError::new(&format!("Failed to serialize integer: {}", e)))?, + )) +} + +/// Converts 4 u64s representing a field element directly to a (rescaled from fixed point scaling) floating point +#[wasm_bindgen] +#[allow(non_snake_case)] +pub fn vecU64ToFloat( + array: wasm_bindgen::Clamped>, + scale: crate::Scale, +) -> Result { + let felt: Fr = serde_json::from_slice(&array[..]) + .map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?; + let int_rep = felt_to_i128(felt); + let multiplier = scale_to_multiplier(scale); + Ok(int_rep as f64 / multiplier) +} + +/// Converts a 
floating point element to 4 u64s representing a fixed point field element +#[wasm_bindgen] +#[allow(non_snake_case)] +pub fn floatToVecU64( + input: f64, + scale: crate::Scale, +) -> Result>, JsError> { + let int_rep = + quantize_float(&input, 0.0, scale).map_err(|e| JsError::new(&format!("{}", e)))?; + let felt = i128_to_felt(int_rep); + let vec = crate::pfsys::field_to_vecu64_montgomery::(&felt); + Ok(wasm_bindgen::Clamped(serde_json::to_vec(&vec).map_err( + |e| JsError::new(&format!("Failed to serialize vecu64_montgomery{}", e)), + )?)) +} + +/// Converts a buffer to vector of 4 u64s representing a fixed point field element +#[wasm_bindgen] +#[allow(non_snake_case)] +pub fn bufferToVecOfVecU64( + buffer: wasm_bindgen::Clamped>, +) -> Result>, JsError> { + // Convert the buffer to a slice + let buffer: &[u8] = &buffer; + + // Divide the buffer into chunks of 64 bytes + let chunks = buffer.chunks_exact(16); + + // Get the remainder + let remainder = chunks.remainder(); + + // Add 0s to the remainder to make it 64 bytes + let mut remainder = remainder.to_vec(); + + // Collect chunks into a Vec<[u8; 16]>. + let chunks: Result, JsError> = chunks + .map(|slice| { + let array: [u8; 16] = slice + .try_into() + .map_err(|_| JsError::new("failed to slice input chunks"))?; + Ok(array) + }) + .collect(); + + let mut chunks = chunks?; + + if remainder.len() != 0 { + remainder.resize(16, 0); + // Convert the Vec to [u8; 16] + let remainder_array: [u8; 16] = remainder + .try_into() + .map_err(|_| JsError::new("failed to slice remainder"))?; + // append the remainder to the chunks + chunks.push(remainder_array); + } + + // Convert each chunk to a field element + let field_elements: Vec = chunks + .iter() + .map(|x| PrimeField::from_u128(u8_array_to_u128_le(*x))) + .collect(); + + Ok(wasm_bindgen::Clamped( + serde_json::to_vec(&field_elements) + .map_err(|e| JsError::new(&format!("Failed to serialize field elements: {}", e)))?, + )) +} + +/// Generate a poseidon hash in browser. 
Input message +#[wasm_bindgen] +#[allow(non_snake_case)] +pub fn poseidonHash( + message: wasm_bindgen::Clamped>, +) -> Result>, JsError> { + let message: Vec = serde_json::from_slice(&message[..]) + .map_err(|e| JsError::new(&format!("Failed to deserialize message: {}", e)))?; + + let output = + PoseidonChip::::run( + message.clone(), + ) + .map_err(|e| JsError::new(&format!("{}", e)))?; + + Ok(wasm_bindgen::Clamped(serde_json::to_vec(&output).map_err( + |e| JsError::new(&format!("Failed to serialize poseidon hash output: {}", e)), + )?)) +} + +/// Generates random elgamal variables from a random seed value in browser. +/// Make sure input seed comes a secure source of randomness +#[wasm_bindgen] +#[allow(non_snake_case)] +pub fn elgamalGenRandom(rng: wasm_bindgen::Clamped>) -> Result, JsError> { + let seed: &[u8] = &rng; + let mut rng = StdRng::from_seed( + seed.try_into() + .map_err(|e| JsError::new(&format!("{}", e)))?, + ); + + let output = crate::circuit::modules::elgamal::ElGamalVariables::gen_random(&mut rng); + + serde_json::to_vec(&output) + .map_err(|e| JsError::new(&format!("Failed to serialize elgamal variables: {}", e))) +} + +/// Encrypt using elgamal in browser. 
Input message +#[wasm_bindgen] +#[allow(non_snake_case)] +pub fn elgamalEncrypt( + pk: wasm_bindgen::Clamped>, + message: wasm_bindgen::Clamped>, + r: wasm_bindgen::Clamped>, +) -> Result, JsError> { + let pk: G1Affine = serde_json::from_slice(&pk[..]) + .map_err(|e| JsError::new(&format!("Failed to deserialize pk: {}", e)))?; + let message: Vec = serde_json::from_slice(&message[..]) + .map_err(|e| JsError::new(&format!("Failed to deserialize message: {}", e)))?; + let r: Fr = serde_json::from_slice(&r[..]) + .map_err(|e| JsError::new(&format!("Failed to deserialize r: {}", e)))?; + + let output = crate::circuit::modules::elgamal::ElGamalGadget::encrypt(pk, message, r); + + serde_json::to_vec(&output) + .map_err(|e| JsError::new(&format!("Failed to serialize cipher {}", e))) +} + +/// Decrypt using elgamal in browser. Input message +#[wasm_bindgen] +#[allow(non_snake_case)] +pub fn elgamalDecrypt( + cipher: wasm_bindgen::Clamped>, + sk: wasm_bindgen::Clamped>, +) -> Result, JsError> { + let sk: Fr = serde_json::from_slice(&sk[..]) + .map_err(|e| JsError::new(&format!("Failed to deserialize sk: {}", e)))?; + + let cipher: ElGamalCipher = serde_json::from_slice(&cipher[..]) + .map_err(|e| JsError::new(&format!("Failed to deserialize cipher: {}", e)))?; + + let output = crate::circuit::modules::elgamal::ElGamalGadget::decrypt(&cipher, sk); + + serde_json::to_vec(&output) + .map_err(|e| JsError::new(&format!("Failed to serialize decrypted cipher: {}", e))) +} + +/// Generate a witness file from input.json, compiled model and a settings.json file. 
+#[wasm_bindgen] +#[allow(non_snake_case)] +pub fn genWitness( + compiled_circuit: wasm_bindgen::Clamped>, + input: wasm_bindgen::Clamped>, +) -> Result, JsError> { + let mut circuit: crate::graph::GraphCircuit = bincode::deserialize(&compiled_circuit[..]) + .map_err(|e| JsError::new(&format!("Failed to deserialize compiled model: {}", e)))?; + let input: crate::graph::input::GraphData = serde_json::from_slice(&input[..]) + .map_err(|e| JsError::new(&format!("Failed to deserialize input: {}", e)))?; + + let mut input = circuit + .load_graph_input(&input) + .map_err(|e| JsError::new(&format!("{}", e)))?; + + let witness = circuit + .forward(&mut input, None, None) + .map_err(|e| JsError::new(&format!("{}", e)))?; + + serde_json::to_vec(&witness) + .map_err(|e| JsError::new(&format!("Failed to serialize witness: {}", e))) +} + +/// Generate verifying key in browser +#[wasm_bindgen] +#[allow(non_snake_case)] +pub fn genVk( + compiled_circuit: wasm_bindgen::Clamped>, + params_ser: wasm_bindgen::Clamped>, +) -> Result, JsError> { + // Read in kzg params + let mut reader = std::io::BufReader::new(¶ms_ser[..]); + let params: ParamsKZG = + halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader) + .map_err(|e| JsError::new(&format!("Failed to deserialize params: {}", e)))?; + // Read in compiled circuit + let circuit: crate::graph::GraphCircuit = bincode::deserialize(&compiled_circuit[..]) + .map_err(|e| JsError::new(&format!("Failed to deserialize compiled model: {}", e)))?; + + // Create verifying key + let vk = create_vk_wasm::, Fr, GraphCircuit>(&circuit, ¶ms) + .map_err(Box::::from) + .map_err(|e| JsError::new(&format!("Failed to create verifying key: {}", e)))?; + + let mut serialized_vk = Vec::new(); + vk.write(&mut serialized_vk, halo2_proofs::SerdeFormat::RawBytes) + .map_err(|e| JsError::new(&format!("Failed to serialize vk: {}", e)))?; + + Ok(serialized_vk) +} + +/// Generate proving key in browser +#[wasm_bindgen] +#[allow(non_snake_case)] +pub 
fn genPk( + vk: wasm_bindgen::Clamped>, + compiled_circuit: wasm_bindgen::Clamped>, + params_ser: wasm_bindgen::Clamped>, +) -> Result, JsError> { + // Read in kzg params + let mut reader = std::io::BufReader::new(¶ms_ser[..]); + let params: ParamsKZG = + halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader) + .map_err(|e| JsError::new(&format!("Failed to deserialize params: {}", e)))?; + // Read in compiled circuit + let circuit: crate::graph::GraphCircuit = bincode::deserialize(&compiled_circuit[..]) + .map_err(|e| JsError::new(&format!("Failed to deserialize compiled model: {}", e)))?; + + // Read in verifying key + let mut reader = std::io::BufReader::new(&vk[..]); + let vk = VerifyingKey::::read::<_, GraphCircuit>( + &mut reader, + halo2_proofs::SerdeFormat::RawBytes, + circuit.settings().clone(), + ) + .map_err(|e| JsError::new(&format!("Failed to deserialize verifying key: {}", e)))?; + // Create proving key + let pk = create_pk_wasm::, Fr, GraphCircuit>(vk, &circuit, ¶ms) + .map_err(Box::::from) + .map_err(|e| JsError::new(&format!("Failed to create proving key: {}", e)))?; + + let mut serialized_pk = Vec::new(); + pk.write(&mut serialized_pk, halo2_proofs::SerdeFormat::RawBytes) + .map_err(|e| JsError::new(&format!("Failed to serialize pk: {}", e)))?; + + Ok(serialized_pk) +} + +/// Verify proof in browser using wasm +#[wasm_bindgen] +pub fn verify( + proof_js: wasm_bindgen::Clamped>, + vk: wasm_bindgen::Clamped>, + settings: wasm_bindgen::Clamped>, + srs: wasm_bindgen::Clamped>, +) -> Result { + let mut reader = std::io::BufReader::new(&srs[..]); + let params: ParamsKZG = + halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader) + .map_err(|e| JsError::new(&format!("Failed to deserialize params: {}", e)))?; + + let circuit_settings: GraphSettings = serde_json::from_slice(&settings[..]) + .map_err(|e| JsError::new(&format!("Failed to deserialize settings: {}", e)))?; + + let snark: crate::pfsys::Snark = 
serde_json::from_slice(&proof_js[..]) + .map_err(|e| JsError::new(&format!("Failed to deserialize proof: {}", e)))?; + + let mut reader = std::io::BufReader::new(&vk[..]); + let vk = VerifyingKey::::read::<_, GraphCircuit>( + &mut reader, + halo2_proofs::SerdeFormat::RawBytes, + circuit_settings, + ) + .map_err(|e| JsError::new(&format!("Failed to deserialize vk: {}", e)))?; + + let strategy = KZGSingleStrategy::new(params.verifier_params()); + + let result = verify_proof_circuit_kzg(params.verifier_params(), snark, &vk, strategy); + + match result { + Ok(_) => Ok(true), + Err(e) => Err(JsError::new(&format!("{}", e))), + } +} + +/// Prove in browser using wasm +#[wasm_bindgen] +pub fn prove( + witness: wasm_bindgen::Clamped>, + pk: wasm_bindgen::Clamped>, + compiled_circuit: wasm_bindgen::Clamped>, + srs: wasm_bindgen::Clamped>, +) -> Result, JsError> { + #[cfg(feature = "det-prove")] + log::set_max_level(log::LevelFilter::Debug); + #[cfg(not(feature = "det-prove"))] + log::set_max_level(log::LevelFilter::Info); + // read in kzg params + let mut reader = std::io::BufReader::new(&srs[..]); + let params: ParamsKZG = + halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader) + .map_err(|e| JsError::new(&format!("Failed to deserialize srs: {}", e)))?; + + // read in circuit + let mut circuit: crate::graph::GraphCircuit = bincode::deserialize(&compiled_circuit[..]) + .map_err(|e| JsError::new(&format!("Failed to deserialize circuit: {}", e)))?; + + // read in model input + let data: crate::graph::GraphWitness = serde_json::from_slice(&witness[..]) + .map_err(|e| JsError::new(&format!("Failed to deserialize witness: {}", e)))?; + + // read in proving key + let mut reader = std::io::BufReader::new(&pk[..]); + let pk = ProvingKey::::read::<_, GraphCircuit>( + &mut reader, + halo2_proofs::SerdeFormat::RawBytes, + circuit.settings().clone(), + ) + .map_err(|e| JsError::new(&format!("Failed to deserialize proving key: {}", e)))?; + + // prep public inputs + 
circuit + .load_graph_witness(&data) + .map_err(|e| JsError::new(&format!("{}", e)))?; + let public_inputs = circuit + .prepare_public_inputs(&data) + .map_err(|e| JsError::new(&format!("{}", e)))?; + let proof_split_commits: Option = data.into(); + + let strategy = KZGSingleStrategy::new(¶ms); + let proof = create_proof_circuit_kzg( + circuit, + ¶ms, + Some(public_inputs), + &pk, + crate::pfsys::TranscriptType::EVM, + strategy, + crate::circuit::CheckMode::UNSAFE, + proof_split_commits, + ) + .map_err(|e| JsError::new(&format!("{}", e)))?; + + Ok(serde_json::to_string(&proof) + .map_err(|e| JsError::new(&format!("{}", e)))? + .into_bytes()) +} + +/// print hex representation of a proof +#[wasm_bindgen] +#[allow(non_snake_case)] +pub fn printProofHex(proof: wasm_bindgen::Clamped>) -> Result { + let proof: crate::pfsys::Snark = serde_json::from_slice(&proof[..]) + .map_err(|e| JsError::new(&format!("Failed to deserialize proof: {}", e)))?; + Ok(hex::encode(proof.proof)) +} +// VALIDATION FUNCTIONS + +/// Witness file validation +#[wasm_bindgen] +#[allow(non_snake_case)] +pub fn witnessValidation(witness: wasm_bindgen::Clamped>) -> Result { + let _: crate::graph::GraphWitness = serde_json::from_slice(&witness[..]) + .map_err(|e| JsError::new(&format!("Failed to deserialize witness: {}", e)))?; + + Ok(true) +} +/// Compiled circuit validation +#[wasm_bindgen] +#[allow(non_snake_case)] +pub fn compiledCircuitValidation( + compiled_circuit: wasm_bindgen::Clamped>, +) -> Result { + let _: crate::graph::GraphCircuit = bincode::deserialize(&compiled_circuit[..]) + .map_err(|e| JsError::new(&format!("Failed to deserialize compiled circuit: {}", e)))?; + + Ok(true) +} +/// Input file validation +#[wasm_bindgen] +#[allow(non_snake_case)] +pub fn inputValidation(input: wasm_bindgen::Clamped>) -> Result { + let _: crate::graph::input::GraphData = serde_json::from_slice(&input[..]) + .map_err(|e| JsError::new(&format!("Failed to deserialize input: {}", e)))?; + + Ok(true) +} 
+/// Proof file validation +#[wasm_bindgen] +#[allow(non_snake_case)] +pub fn proofValidation(proof: wasm_bindgen::Clamped>) -> Result { + let _: crate::pfsys::Snark = serde_json::from_slice(&proof[..]) + .map_err(|e| JsError::new(&format!("Failed to deserialize proof: {}", e)))?; + + Ok(true) +} +/// Vk file validation +#[wasm_bindgen] +#[allow(non_snake_case)] +pub fn vkValidation( + vk: wasm_bindgen::Clamped>, + settings: wasm_bindgen::Clamped>, +) -> Result { + let circuit_settings: GraphSettings = serde_json::from_slice(&settings[..]) + .map_err(|e| JsError::new(&format!("Failed to deserialize settings: {}", e)))?; + let mut reader = std::io::BufReader::new(&vk[..]); + let _ = VerifyingKey::::read::<_, GraphCircuit>( + &mut reader, + halo2_proofs::SerdeFormat::RawBytes, + circuit_settings, + ) + .map_err(|e| JsError::new(&format!("Failed to deserialize vk: {}", e)))?; + + Ok(true) +} +/// Pk file validation +#[wasm_bindgen] +#[allow(non_snake_case)] +pub fn pkValidation( + pk: wasm_bindgen::Clamped>, + settings: wasm_bindgen::Clamped>, +) -> Result { + let circuit_settings: GraphSettings = serde_json::from_slice(&settings[..]) + .map_err(|e| JsError::new(&format!("Failed to deserialize settings: {}", e)))?; + let mut reader = std::io::BufReader::new(&pk[..]); + let _ = ProvingKey::::read::<_, GraphCircuit>( + &mut reader, + halo2_proofs::SerdeFormat::RawBytes, + circuit_settings, + ) + .map_err(|e| JsError::new(&format!("Failed to deserialize proving key: {}", e)))?; + + Ok(true) +} +/// Settings file validation +#[wasm_bindgen] +#[allow(non_snake_case)] +pub fn settingsValidation(settings: wasm_bindgen::Clamped>) -> Result { + let _: GraphSettings = serde_json::from_slice(&settings[..]) + .map_err(|e| JsError::new(&format!("Failed to deserialize settings: {}", e)))?; + + Ok(true) +} +/// Srs file validation +#[wasm_bindgen] +#[allow(non_snake_case)] +pub fn srsValidation(srs: wasm_bindgen::Clamped>) -> Result { + let mut reader = 
std::io::BufReader::new(&srs[..]); + let _: ParamsKZG = + halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader) + .map_err(|e| JsError::new(&format!("Failed to deserialize params: {}", e)))?; + + Ok(true) +} + +// HELPER FUNCTIONS + +/// Creates a [ProvingKey] for a [GraphCircuit] (`circuit`) with specific [CommitmentScheme] parameters (`params`) for the WASM target +#[cfg(target_arch = "wasm32")] +pub fn create_vk_wasm>( + circuit: &C, + params: &'_ Scheme::ParamsProver, +) -> Result, halo2_proofs::plonk::Error> +where + C: Circuit, + ::Scalar: FromUniformBytes<64>, +{ + // Real proof + let empty_circuit = >::without_witnesses(circuit); + + // Initialize the verifying key + let vk = keygen_vk(params, &empty_circuit)?; + Ok(vk) +} +/// Creates a [ProvingKey] from a [VerifyingKey] for a [GraphCircuit] (`circuit`) with specific [CommitmentScheme] parameters (`params`) for the WASM target +#[cfg(target_arch = "wasm32")] +pub fn create_pk_wasm>( + vk: VerifyingKey, + circuit: &C, + params: &'_ Scheme::ParamsProver, +) -> Result, halo2_proofs::plonk::Error> +where + C: Circuit, + ::Scalar: FromUniformBytes<64>, +{ + // Real proof + let empty_circuit = >::without_witnesses(circuit); + + // Initialize the proving key + let pk = keygen_pk(params, vk, &empty_circuit)?; + Ok(pk) +} + +/// +pub fn u8_array_to_u128_le(arr: [u8; 16]) -> u128 { + let mut n: u128 = 0; + for &b in arr.iter().rev() { + n <<= 8; + n |= b as u128; + } + n +} diff --git a/mnist_zkml/Cargo.lock b/mnist_zkml/Cargo.lock new file mode 100644 index 0000000..eb49478 --- /dev/null +++ b/mnist_zkml/Cargo.lock @@ -0,0 +1,712 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. 
+version = 3 + +[[package]] +name = "arrayref" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b4930d2cb77ce62f89ee5d5289b4ac049559b1c45539271f5ed4fdc7db34545" + +[[package]] +name = "arrayvec" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" + +[[package]] +name = "autocfg" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" + +[[package]] +name = "bitvec" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c" +dependencies = [ + "funty", + "radium", + "tap", + "wyz", +] + +[[package]] +name = "blake2b_simd" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23285ad32269793932e830392f2fe2f83e26488fd3ec778883a93c8323735780" +dependencies = [ + "arrayref", + "arrayvec", + "constant_time_eq", +] + +[[package]] +name = "block-buffer" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4152116fd6e9dadb291ae18fc1ec3575ed6d84c29642d97890f4b4a3417297e4" +dependencies = [ + "block-padding", + "generic-array", +] + +[[package]] +name = "block-padding" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d696c370c750c948ada61c69a0ee2cbbb9c50b1019ddb86d9317157a99c2cae" + +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + 
+[[package]] +name = "constant_time_eq" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7144d30dcf0fafbce74250a3963025d8d52177934239851c917d29f1df280c2" + +[[package]] +name = "cpufeatures" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce420fe07aecd3e67c5f910618fe65e94158f6dcc0adf44e00d69ce2bdfe0fd0" +dependencies = [ + "libc", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce6fd6f855243022dcecf8702fef0c297d4338e226845fe067f6341ad9fa0cef" +dependencies = [ + "cfg-if", + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae211234986c545741a7dc064309f67ee1e5ad243d0e48335adc0484d960bcc7" +dependencies = [ + "autocfg", + "cfg-if", + "crossbeam-utils", + "memoffset", + "scopeguard", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a22b2d63d4d1dc0b7f1b6b2747dd0088008a9be28b6ddf0b1e7d335e3037294" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "crunchy" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" + +[[package]] +name = "digest" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3dd60d1080a57a05ab032377049e0591415d2b31afd7028356dbf3cc6dcb066" +dependencies = [ + "generic-array", +] + +[[package]] +name = "either" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" + +[[package]] +name = "ff" +version = "0.13.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "ded41244b729663b1e574f1b4fb731469f69f79c17667b5d776b16cda0479449" +dependencies = [ + "bitvec", + "rand_core", + "subtle", +] + +[[package]] +name = "funty" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" + +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + +[[package]] +name = "getrandom" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe9006bed769170c11f845cf00c7c1e9092aeb3f268e007c3e760ac68008070f" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "group" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0f9ef7462f7c099f518d754361858f86d8a07af53ba9af0fe635bbccb151a63" +dependencies = [ + "ff", + "rand_core", + "subtle", +] + +[[package]] +name = "halo2" +version = "0.1.0-beta.2" +source = "git+https://github.com/privacy-scaling-explorations/halo2?rev=17e9765c199670534c0299c96128d0464a188d0b#17e9765c199670534c0299c96128d0464a188d0b" +dependencies = [ + "halo2_proofs", +] + +[[package]] +name = "halo2_gadgets" +version = "0.2.0" +source = "git+https://github.com/privacy-scaling-explorations/halo2?rev=17e9765c199670534c0299c96128d0464a188d0b#17e9765c199670534c0299c96128d0464a188d0b" +dependencies = [ + "arrayvec", + "bitvec", + "ff", + "group", + "halo2_proofs", + "halo2curves", + "lazy_static", + "rand", + "subtle", + "uint", +] + +[[package]] +name = "halo2_proofs" +version = "0.2.0" +source = "git+https://github.com/privacy-scaling-explorations/halo2?rev=17e9765c199670534c0299c96128d0464a188d0b#17e9765c199670534c0299c96128d0464a188d0b" +dependencies = [ 
+ "blake2b_simd", + "ff", + "group", + "halo2curves", + "rand_chacha", + "rand_core", + "rayon", + "sha3", + "tracing", +] + +[[package]] +name = "halo2curves" +version = "0.3.2" +source = "git+https://github.com/privacy-scaling-explorations/halo2curves?tag=0.3.2#9f5c50810bbefe779ee5cf1d852b2fe85dc35d5e" +dependencies = [ + "ff", + "group", + "lazy_static", + "num-bigint", + "num-traits", + "pasta_curves", + "paste", + "rand", + "rand_core", + "static_assertions", + "subtle", +] + +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + +[[package]] +name = "itoa" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" + +[[package]] +name = "keccak" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f6d5ed8676d904364de097082f4e7d240b571b67989ced0240f08b7f966f940" +dependencies = [ + "cpufeatures", +] + +[[package]] +name = "lazy_static" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" +dependencies = [ + "spin", +] + +[[package]] +name = "libc" +version = "0.2.150" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89d92a4743f9a61002fae18374ed11e7973f530cb3a3255fb354818118b2203c" + +[[package]] +name = "matrixmultiply" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7574c1cf36da4798ab73da5b215bbf444f50718207754cb522201d78d1cd0ff2" +dependencies = [ + "autocfg", + "rawpointer", +] + +[[package]] +name = "memoffset" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" +dependencies 
= [ + "autocfg", +] + +[[package]] +name = "ndarray" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "rawpointer", +] + +[[package]] +name = "num-bigint" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "608e7659b5c3d7cba262d894801b9ec9d00de989e8a82bd4bef91d08da45cdc0" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-complex" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ba157ca0885411de85d6ca030ba7e2a83a28636056c7c699b07c8b6f7383214" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" +dependencies = [ + "autocfg", + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c" +dependencies = [ + "autocfg", +] + +[[package]] +name = "once_cell" +version = "1.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" + +[[package]] +name = "opaque-debug" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" + +[[package]] +name = "pasta_curves" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3e57598f73cc7e1b2ac63c79c517b31a0877cd7c402cdcaa311b5208de7a095" +dependencies = [ + "blake2b_simd", + "ff", + "group", + "lazy_static", + 
"rand", + "static_assertions", + "subtle", +] + +[[package]] +name = "paste" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c" + +[[package]] +name = "pin-project-lite" +version = "0.2.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58" + +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + +[[package]] +name = "proc-macro2" +version = "1.0.70" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39278fbbf5fb4f646ce651690877f89d1c5811a3d4acb27700c1cb3cdb78fd3b" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "radium" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + +[[package]] +name = "rayon" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c27db03db7734835b3f53954b534c91069375ce6ccaa2e065441e07d9b6cdb1" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ce3fb6ad83f861aac485e76e1985cd109d9a3713802152be56c3b1f0e0658ed" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + +[[package]] +name = "riff" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9b1a3d5f46d53f4a3478e2be4a5a5ce5108ea58b100dcd139830eae7f79a3a1" + +[[package]] +name = "rmp" +version = "0.8.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f9860a6cc38ed1da53456442089b4dfa35e7cedaa326df63017af88385e6b20" +dependencies = [ + "byteorder", + "num-traits", + "paste", +] + +[[package]] +name = "rmp-serde" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bffea85eea980d8a74453e5d02a8d93028f3c34725de143085a844ebe953258a" +dependencies = [ + "byteorder", + "rmp", + "serde", +] + +[[package]] +name = "rounded-div" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "464c8fb0a126d6a0326baf6abf1aa62c2da0d5780aa781a81451d64f543f5e2f" + +[[package]] +name = "ryu" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741" + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "serde" +version = "1.0.193" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25dd9975e68d0cb5aa1120c288333fc98731bd1dd12f561e468ea4728c042b89" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.193" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.108" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d1c7e3eac408d115102c4c24ad393e0821bb3a5df4d506a80f85f7a742a526b" +dependencies = [ + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "sha3" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f81199417d4e5de3f04b1e871023acea7389672c4135918f05aa9cbf2f2fa809" +dependencies = [ + "block-buffer", + "digest", + "keccak", + "opaque-debug", +] + +[[package]] +name = "spin" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" + +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + +[[package]] +name = "subtle" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" + +[[package]] +name = "syn" +version = "2.0.39" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23e78b90f2fcf45d3e842032ce32e3f2d1545ba6636271dcbf24fa306d87be7a" +dependencies = [ + "proc-macro2", + "quote", + 
"unicode-ident", +] + +[[package]] +name = "tap" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" + +[[package]] +name = "tracing" +version = "0.1.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" +dependencies = [ + "pin-project-lite", + "tracing-attributes", + "tracing-core", +] + +[[package]] +name = "tracing-attributes" +version = "0.1.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tracing-core" +version = "0.1.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" +dependencies = [ + "once_cell", +] + +[[package]] +name = "typenum" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" + +[[package]] +name = "uint" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76f64bba2c53b04fcab63c01a7d7427eadc821e3bc48c34dc9ba29c501164b52" +dependencies = [ + "byteorder", + "crunchy", + "hex", + "static_assertions", +] + +[[package]] +name = "unicode-ident" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" + +[[package]] +name = "version_check" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "wav" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a65e199c799848b4f997072aa4d673c034f80f40191f97fe2f0a23f410be1609" +dependencies = [ + "riff", +] + +[[package]] +name = "wyz" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed" +dependencies = [ + "tap", +] + +[[package]] +name = "zkml" +version = "0.0.1" +dependencies = [ + "bitvec", + "halo2", + "halo2_gadgets", + "halo2_proofs", + "lazy_static", + "ndarray", + "num-bigint", + "num-traits", + "once_cell", + "rand", + "rmp-serde", + "rounded-div", + "serde", + "serde_derive", + "serde_json", + "wav", +] diff --git a/mnist_zkml/Cargo.toml b/mnist_zkml/Cargo.toml new file mode 100644 index 0000000..43286f2 --- /dev/null +++ b/mnist_zkml/Cargo.toml @@ -0,0 +1,39 @@ +[package] +name = "zkml" +version = "0.0.1" +edition = "2021" +description = "Zero-knowledge machine learning" +license = "LICENSE" +homepage = "https://github.com/ddkang/zkml" +repository = "https://github.com/ddkang/zkml-public.git" +readme = "README.md" +exclude = [ + "params", + "params_kzg", + "python", +] + +[profile.dev] +opt-level = 3 + +[profile.test] +opt-level = 3 + +[dependencies] +bitvec = "1.0.1" +halo2 = { git="https://github.com/privacy-scaling-explorations/halo2", package="halo2", rev="17e9765c199670534c0299c96128d0464a188d0b" } +halo2_gadgets = { git="https://github.com/privacy-scaling-explorations/halo2", package="halo2_gadgets", rev="17e9765c199670534c0299c96128d0464a188d0b", features = ["circuit-params"] } +halo2_proofs = { git="https://github.com/privacy-scaling-explorations/halo2", package="halo2_proofs", rev="17e9765c199670534c0299c96128d0464a188d0b", features = ["circuit-params"] } +lazy_static = "1.4.0" +ndarray 
= "0.15.6" +num-bigint = "0.4.3" +num-traits = "0.2.15" +once_cell = "1.15.0" +rand = "0.8.5" +rmp-serde = "1.1.1" +rounded-div = "0.1.2" +serde = "1.0.152" +serde_derive = "1.0.152" +serde_json = "1.0.85" +wav = "1.0.0" + diff --git a/mnist_zkml/benches/bench.rs b/mnist_zkml/benches/bench.rs new file mode 100644 index 0000000..393f183 --- /dev/null +++ b/mnist_zkml/benches/bench.rs @@ -0,0 +1,236 @@ +// [existing imports] + +// Additional imports for handling MNIST data +use mnist::{Mnist, MnistBuilder}; +use std::{collections::HashMap}; +use tch::{Tensor, Kind, Device, vision, Scalar}; +// Constants for MNIST data +const IMAGE_SIZE: usize = 28 * 28; // MNIST images are 28x28 +// ... other constants as needed + + +struct MNISTModel { + l1: Linear, + l2: Linear, +} + +impl MNISTModel { + fn new (mem: &mut Memory) -> MNISTModel { + let l1 = Linear::new(mem, 784, 128); + let l2 = Linear::new(mem, 128, 10); + Self { + l1: l1, + l2: l2, + } + } +} + +impl Compute for MNISTModel { + fn forward (&self, mem: &Memory, input: &Tensor) -> Tensor { + let mut o = self.l1.forward(mem, input); + o = o.relu(); + o = self.l2.forward(mem, &o); + o + } +} + +fn main() { + let (x, y) = load_mnist(); + + let mut m = Memory::new(); + let mnist_model = MNISTModel::new(&mut m); + train(&mut m, &x, &y, &mnist_model, 100, 128, cross_entropy, 0.3); + let out = mnist_model.forward(&m, &x); + println!("Training Accuracy: {}", accuracy(&y, &out)); +} + +trait Compute { + fn forward (&self, mem: &Memory, input: &Tensor) -> Tensor; +} + +struct Linear { + params: HashMap, +} + +impl Linear { + fn new (mem: &mut Memory, ninputs: i64, noutputs: i64) -> Self { + let mut p = HashMap::new(); + p.insert("W".to_string(), mem.new_push(&[ninputs,noutputs], true)); + p.insert("b".to_string(), mem.new_push(&[1, noutputs], true)); + + Self { + params: p, + } + } +} + +impl Compute for Linear { + fn forward (&self, mem: &Memory, input: &Tensor) -> Tensor { + let w = 
mem.get(self.params.get(&"W".to_string()).unwrap()); + let b = mem.get(self.params.get(&"b".to_string()).unwrap()); + input.matmul(w) + b + } +} + +fn mse(target: &Tensor, pred: &Tensor) -> Tensor { + (target - pred).square().mean(Kind::Float) +} + +fn cross_entropy (target: &Tensor, pred: &Tensor) -> Tensor { + let loss = pred.log_softmax(-1, Kind::Float).nll_loss(target); + loss +} + +struct Memory { + size: usize, + values: Vec, +} + +impl Memory { + + fn new() -> Self { + let v = Vec::new(); + Self {size: 0, + values: v} + } + + fn push (&mut self, value: Tensor) -> usize { + self.values.push(value); + self.size += 1; + self.size-1 + } + + fn new_push (&mut self, size: &[i64], requires_grad: bool) -> usize { + let t = Tensor::randn(size, (Kind::Float, Device::Cpu)).requires_grad_(requires_grad); + self.push(t) + } + + fn get (&self, addr: &usize) -> &Tensor { + &self.values[*addr] + } + + fn apply_grads_sgd(&mut self, learning_rate: f32) { + let mut g = Tensor::new(); + self.values + .iter_mut() + .for_each(|t| { + if t.requires_grad() { + g = t.grad(); + t.set_data(&(t.data() - learning_rate*&g)); + t.zero_grad(); + } + }); + } + + fn apply_grads_sgd_momentum(&mut self, learning_rate: f32) { + let mut g: Tensor = Tensor::new(); + let mut velocity: Vec= Tensor::zeros(&[self.size as i64], (Kind::Float, Device::Cpu)).split(1, 0); + let mut vcounter = 0; + const BETA:f32 = 0.9; + + self.values + .iter_mut() + .for_each(|t| { + if t.requires_grad() { + g = t.grad(); + velocity[vcounter] = BETA * &velocity[vcounter] + (1.0 - BETA) * &g; + t.set_data(&(t.data() - learning_rate * &velocity[vcounter])); + t.zero_grad(); + } + vcounter += 1; + }); + } +} + +fn train(mem: &mut Memory, x: &Tensor, y: &Tensor, model: &dyn Compute, epochs: i64, batch_size: i64, errfunc: F, learning_rate: f32) + where F: Fn(&Tensor, &Tensor)-> Tensor + { + let mut error = Tensor::from(0.0); + let mut batch_error = Tensor::from(0.0); + let mut pred = Tensor::from(0.0); + for epoch in 
0..epochs { + batch_error = Tensor::from(0.0); + for (batchx, batchy) in get_batches(&x, &y, batch_size, true) { + pred = model.forward(mem, &batchx); + error = errfunc(&batchy, &pred); + batch_error += error.detach(); + error.backward(); + mem.apply_grads_sgd_momentum(learning_rate); + } + println!("Epoch: {:?} Error: {:?}", epoch, batch_error/batch_size); + } +} + +fn get_batches(x: &Tensor, y: &Tensor, batch_size: i64, shuffle: bool) -> impl Iterator { + let num_rows = x.size()[0]; + let num_batches = (num_rows + batch_size - 1) / batch_size; + + let indices = if shuffle { + Tensor::randperm(num_rows as i64, (Kind::Int64, Device::Cpu)) + } else + { + let rng = (0..num_rows).collect::>(); + Tensor::from_slice(&rng) + }; + let x = x.index_select(0, &indices); + let y = y.index_select(0, &indices); + + (0..num_batches).map(move |i| { + let start = i * batch_size; + let end = (start + batch_size).min(num_rows); + let batchx: Tensor = x.narrow(0, start, end - start); + let batchy: Tensor = y.narrow(0, start, end - start); + (batchx, batchy) + }) +} + + +fn load_mnist() -> (Tensor, Tensor) { + let data = MnistBuilder::new() + .label_format_digit() + .training_set_length(1000) + .validation_set_length(1000) + .finalize(); + + let train_images = Tensor::from_slice(&data.trn_img); + let val_images = Tensor::from_slice(&data.trn_lbl); + let x = train_images; + let y = val_images; + (x, y) +} + +fn accuracy(target: &Tensor, pred: &Tensor) -> f64 { + let yhat = pred.argmax(1,true).squeeze(); + let eq = target.eq_tensor(&yhat); + let accuracy: f64 = (eq.sum(Kind::Int64) / target.size()[0]).double_value(&[]).into(); + accuracy +} + +// #[benchy::benchmark] +// fn run_mnist(bench: &mut BenchmarkRun) { +// // Load MNIST data +// let (mnist_train, mnist_test) = MnistBuilder::new() +// .label_format_digit() +// .finalize(); + +// for &k in [8, 10, 11, 12, 13, 14].iter() { +// // ... 
+ +// // Use MNIST data for the tensors +// let mut image_tensor = Tensor::from(mnist_train.trn_img.iter().map(|&v| Value::known(Fr::from(v as u64)))); +// image_tensor.reshape(&[IMAGE_SIZE, 1]).unwrap(); +// let mut label_tensor = Tensor::from(mnist_train.trn_lbl.iter().map(|&v| Value::known(Fr::from(v as u64)))); +// label_tensor.reshape(&[1, 1]).unwrap(); + +// let circuit = MNISTCircuit { +// image: ValTensor::from(image_tensor), +// label: ValTensor::from(label_tensor), +// _marker: PhantomData, +// }; + +// // Benchmarking logic remains the same +// // ... +// } +// } + +// benchy::main!(run_mnist); diff --git a/mnist_zkml/src/bin/test_circuit.rs b/mnist_zkml/src/bin/test_circuit.rs new file mode 100644 index 0000000..6bd66ef --- /dev/null +++ b/mnist_zkml/src/bin/test_circuit.rs @@ -0,0 +1,23 @@ +use halo2_proofs::{dev::MockProver, halo2curves::bn256::Fr}; +use zkml::{ + model::ModelCircuit, + utils::{ + helpers::get_public_values, + loader::{load_model_msgpack, ModelMsgpack}, + }, +}; + +fn main() { + let config_fname = std::env::args().nth(1).expect("config file path"); + let inp_fname = std::env::args().nth(2).expect("input file path"); + + let config: ModelMsgpack = load_model_msgpack(&config_fname, &inp_fname); + + let circuit = ModelCircuit::::generate_from_file(&config_fname, &inp_fname); + + let _prover = MockProver::run(config.k.try_into().unwrap(), &circuit, vec![vec![]]).unwrap(); + let public_vals = get_public_values(); + + let prover = MockProver::run(config.k.try_into().unwrap(), &circuit, vec![public_vals]).unwrap(); + assert_eq!(prover.verify(), Ok(())); +} diff --git a/mnist_zkml/src/bin/time_circuit.rs b/mnist_zkml/src/bin/time_circuit.rs new file mode 100644 index 0000000..f4f4094 --- /dev/null +++ b/mnist_zkml/src/bin/time_circuit.rs @@ -0,0 +1,23 @@ +use halo2_proofs::halo2curves::{bn256::Fr, pasta::Fp}; +use zkml::{ + model::ModelCircuit, + utils::{proving_ipa::time_circuit_ipa, proving_kzg::time_circuit_kzg}, +}; + +fn main() { + let 
config_fname = std::env::args().nth(1).expect("config file path"); + let inp_fname = std::env::args().nth(2).expect("input file path"); + let kzg_or_ipa = std::env::args().nth(3).expect("kzg or ipa"); + + if kzg_or_ipa != "kzg" && kzg_or_ipa != "ipa" { + panic!("Must specify kzg or ipa"); + } + + if kzg_or_ipa == "kzg" { + let circuit = ModelCircuit::::generate_from_file(&config_fname, &inp_fname); + time_circuit_kzg(circuit); + } else { + let circuit = ModelCircuit::::generate_from_file(&config_fname, &inp_fname); + time_circuit_ipa(circuit); + } +} diff --git a/mnist_zkml/src/bin/verify_circuit.rs b/mnist_zkml/src/bin/verify_circuit.rs new file mode 100644 index 0000000..e0117bc --- /dev/null +++ b/mnist_zkml/src/bin/verify_circuit.rs @@ -0,0 +1,27 @@ +use halo2_proofs::halo2curves::bn256::Fr; +use zkml::{ + model::ModelCircuit, + utils::{loader::load_config_msgpack, proving_kzg::verify_circuit_kzg}, +}; + +fn main() { + let config_fname = std::env::args().nth(1).expect("config file path"); + let vkey_fname = std::env::args().nth(2).expect("verification key file path"); + let proof_fname = std::env::args().nth(3).expect("proof file path"); + let public_vals_fname = std::env::args().nth(4).expect("public values file path"); + let kzg_or_ipa = std::env::args().nth(5).expect("kzg or ipa"); + + if kzg_or_ipa != "kzg" && kzg_or_ipa != "ipa" { + panic!("Must specify kzg or ipa"); + } + + if kzg_or_ipa == "kzg" { + let config = load_config_msgpack(&config_fname); + let circuit = ModelCircuit::::generate_from_msgpack(config, false); + println!("Loaded configuration"); + verify_circuit_kzg(circuit, &vkey_fname, &proof_fname, &public_vals_fname); + } else { + // Serialization of the verification key doesn't seem to be supported for IPA + panic!("Not implemented"); + } +} diff --git a/mnist_zkml/src/bin/verify_wav.rs b/mnist_zkml/src/bin/verify_wav.rs new file mode 100644 index 0000000..bf6d6fe --- /dev/null +++ b/mnist_zkml/src/bin/verify_wav.rs @@ -0,0 +1,46 @@ +use 
std::fs::File; + +use halo2_proofs::{dev::MockProver, halo2curves::bn256::Fr}; +use zkml::{ + model::ModelCircuit, + utils::{ + helpers::get_public_values, + loader::{load_config_msgpack, ModelMsgpack, TensorMsgpack}, + }, +}; + +fn main() { + let config_fname = std::env::args().nth(1).expect("config file path"); + let wav_fname = std::env::args().nth(2).expect("wav file path"); + + let mut wav_file = File::open(wav_fname).unwrap(); + let (_header, data) = wav::read(&mut wav_file).unwrap(); + let data = match data { + wav::BitDepth::Sixteen(data) => data, + _ => panic!("Unsupported bit depth"), + }; + let data: Vec = data.iter().map(|x| *x as i64).collect(); + + let base_config = load_config_msgpack(&config_fname); + + let config = ModelMsgpack { + tensors: vec![TensorMsgpack { + idx: 0, + shape: vec![1, data.len().try_into().unwrap()], + data: data, + }], + inp_idxes: vec![0], + out_idxes: vec![], + layers: vec![], + commit_before: Some(vec![]), + commit_after: Some(vec![vec![0]]), + ..base_config + }; + println!("Config: {:?}", config); + let k = config.k; + let circuit = ModelCircuit::::generate_from_msgpack(config, false); + + let _prover = MockProver::run(k.try_into().unwrap(), &circuit, vec![vec![]]).unwrap(); + let public_vals: Vec = get_public_values(); + println!("Public values: {:?}", public_vals); +} diff --git a/mnist_zkml/src/commitments.rs b/mnist_zkml/src/commitments.rs new file mode 100644 index 0000000..5f9e665 --- /dev/null +++ b/mnist_zkml/src/commitments.rs @@ -0,0 +1,3 @@ +pub mod commit; +pub mod packer; +pub mod poseidon_commit; diff --git a/mnist_zkml/src/commitments/commit.rs b/mnist_zkml/src/commitments/commit.rs new file mode 100644 index 0000000..305578f --- /dev/null +++ b/mnist_zkml/src/commitments/commit.rs @@ -0,0 +1,16 @@ +use std::{collections::HashMap, rc::Rc}; + +use halo2_proofs::{circuit::Layouter, halo2curves::ff::PrimeField, plonk::Error}; + +use crate::{gadgets::gadget::GadgetConfig, layers::layer::CellRc}; + +pub trait 
Commit { + fn commit( + &self, + layouter: impl Layouter, + gadget_config: Rc, + constants: &HashMap>, + values: &Vec>, + blinding: CellRc, + ) -> Result>, Error>; +} diff --git a/mnist_zkml/src/commitments/packer.rs b/mnist_zkml/src/commitments/packer.rs new file mode 100644 index 0000000..2ddbaff --- /dev/null +++ b/mnist_zkml/src/commitments/packer.rs @@ -0,0 +1,373 @@ +use std::{ + cmp::{max, min}, + collections::{BTreeMap, HashMap}, + marker::PhantomData, + rc::Rc, +}; + +use halo2_proofs::{ + circuit::{AssignedCell, Layouter, Value}, + halo2curves::ff::PrimeField, + plonk::{ConstraintSystem, Error, Expression}, + poly::Rotation, +}; +use ndarray::{Array, IxDyn}; + +use crate::{ + gadgets::gadget::{GadgetConfig, GadgetType}, + layers::layer::{AssignedTensor, CellRc}, +}; + +const NUM_BITS_PER_FIELD_ELEM: usize = 254; + +pub struct PackerConfig { + pub num_bits_per_elem: usize, + pub num_elem_per_packed: usize, + pub num_packed_per_row: usize, + pub exponents: Vec, + _marker: PhantomData, +} + +pub struct PackerChip { + pub config: PackerConfig, +} + +impl PackerChip { + pub fn get_exponents(num_bits_per_elem: usize, num_exponents: usize) -> Vec { + let mul_val = F::from(1 << num_bits_per_elem); + let mut exponents = vec![F::ONE]; + for _ in 1..num_exponents { + exponents.push(exponents[exponents.len() - 1] * mul_val); + } + exponents + } + + pub fn construct(num_bits_per_elem: usize, gadget_config: &GadgetConfig) -> PackerConfig { + let columns = &gadget_config.columns; + + let num_elem_per_packed = if NUM_BITS_PER_FIELD_ELEM / num_bits_per_elem > columns.len() - 1 { + columns.len() - 1 + } else { + // TODO: for many columns, pack many in a single row + NUM_BITS_PER_FIELD_ELEM / num_bits_per_elem + }; + println!("column len: {}", columns.len()); + println!("num_bits_per_elem: {}", num_bits_per_elem); + println!("NUM_BITS_PER_FIELD_ELEM: {}", NUM_BITS_PER_FIELD_ELEM); + println!("num_elem_per_packed: {}", num_elem_per_packed); + + let num_packed_per_row = max( 
+ 1, + columns.len() / (num_elem_per_packed * (num_bits_per_elem + 1)), + ); + println!("num_packed_per_row: {}", num_packed_per_row); + + let exponents = Self::get_exponents(num_bits_per_elem, num_elem_per_packed); + + let config = PackerConfig { + num_bits_per_elem, + num_elem_per_packed, + num_packed_per_row, + exponents, + _marker: PhantomData, + }; + config + } + + pub fn configure( + meta: &mut ConstraintSystem, + packer_config: PackerConfig, + gadget_config: GadgetConfig, + ) -> GadgetConfig { + let selector = meta.complex_selector(); + let columns = gadget_config.columns; + let lookup = gadget_config.tables.get(&GadgetType::InputLookup).unwrap()[0]; + + let exponents = &packer_config.exponents; + + let num_bits_per_elem = packer_config.num_bits_per_elem; + let shift_val = 1 << (num_bits_per_elem - 1); + let shift_val = Expression::Constant(F::from(shift_val as u64)); + + meta.create_gate("packer", |meta| { + let s = meta.query_selector(selector); + let mut constraints = vec![]; + for i in 0..packer_config.num_packed_per_row { + let offset = i * (packer_config.num_elem_per_packed + 1); + let inps = columns[offset..offset + packer_config.num_elem_per_packed] + .iter() + .map(|col| meta.query_advice(*col, Rotation::cur())) + .collect::>(); + + let outp = meta.query_advice( + columns[offset + packer_config.num_elem_per_packed], + Rotation::cur(), + ); + + let res = inps + .into_iter() + .zip(exponents.iter()) + .map(|(inp, exp)| (inp + shift_val.clone()) * (*exp)) + .fold(Expression::Constant(F::ZERO), |acc, prod| acc + prod); + constraints.push(s.clone() * (res - outp)); + // constraints.push(s.clone() * Expression::Constant(F::zero())); + } + + constraints + }); + + // Ensure that the weights/inputs are in the correct range + for i in 0..packer_config.num_packed_per_row { + let offset = i * (packer_config.num_elem_per_packed + 1); + for j in 0..packer_config.num_elem_per_packed { + meta.lookup("packer lookup", |meta| { + let s = meta.query_selector(selector); 
+ let inp = meta.query_advice(columns[offset + j], Rotation::cur()); + + vec![(s * (inp + shift_val.clone()), lookup)] + }); + } + } + + let mut selectors = gadget_config.selectors; + selectors.insert(GadgetType::Packer, vec![selector]); + + GadgetConfig { + columns, + selectors, + ..gadget_config + } + } + + pub fn copy_and_pack_row( + &self, + mut layouter: impl Layouter, + gadget_config: Rc, + cells: Vec>, + zero: &AssignedCell, + ) -> Result>, Error> { + let columns = &gadget_config.columns; + let selector = gadget_config.selectors.get(&GadgetType::Packer).unwrap()[0]; + + let num_bits_per_elem = gadget_config.num_bits_per_elem; + let shift_val = 1 << (num_bits_per_elem - 1); + let shift_val = F::from(shift_val as u64); + + let outp = layouter.assign_region( + || "pack row", + |mut region| { + if gadget_config.use_selectors { + selector.enable(&mut region, 0)?; + } + + let mut packed = vec![]; + for i in 0..self.config.num_packed_per_row { + let val_offset = i * self.config.num_elem_per_packed; + let col_offset = i * (self.config.num_elem_per_packed + 1); + + let mut vals = cells + [val_offset..min(val_offset + self.config.num_elem_per_packed, cells.len())] + .iter() + .enumerate() + .map(|(i, x)| { + x.copy_advice(|| "", &mut region, columns[col_offset + i], 0) + .unwrap(); + x.value().copied() + }) + .collect::>(); + + let zero_copied = (cells.len()..self.config.num_elem_per_packed) + .map(|i| { + zero + .copy_advice(|| "", &mut region, columns[col_offset + i], 0) + .unwrap(); + zero.value().copied() + }) + .collect::>(); + vals.extend(zero_copied); + + let res = vals.iter().zip(self.config.exponents.iter()).fold( + Value::known(F::ZERO), + |acc, (inp, exp)| { + let res = acc + (*inp + Value::known(shift_val)) * Value::known(*exp); + res + }, + ); + + let outp = region.assign_advice( + || "", + columns[col_offset + self.config.num_elem_per_packed], + 0, + || res, + )?; + packed.push(Rc::new(outp)); + } + + Ok(packed) + }, + )?; + + Ok(outp) + } + + pub fn 
assign_and_pack_row( + &self, + mut layouter: impl Layouter, + gadget_config: Rc, + values: Vec<&F>, + zero: &AssignedCell, + ) -> Result<(Vec>, Vec>), Error> { + let columns = &gadget_config.columns; + let selector = gadget_config.selectors.get(&GadgetType::Packer).unwrap()[0]; + + let num_bits_per_elem = gadget_config.num_bits_per_elem; + let shift_val = 1 << (num_bits_per_elem - 1); + let shift_val = F::from(shift_val as u64); + + let outp = layouter.assign_region( + || "pack row", + |mut region| { + if gadget_config.use_selectors { + selector.enable(&mut region, 0)?; + } + + let mut packed = vec![]; + let mut assigned = vec![]; + for i in 0..self.config.num_packed_per_row { + let val_offset = i * self.config.num_elem_per_packed; + let col_offset = i * (self.config.num_elem_per_packed + 1); + + let mut values = values + [val_offset..min(val_offset + self.config.num_elem_per_packed, values.len())] + .iter() + .map(|x| **x) + .collect::>(); + let vals = values + .iter() + .enumerate() + .map(|(i, x)| { + let tmp = region + .assign_advice(|| "", columns[col_offset + i], 0, || Value::known(*x)) + .unwrap(); + Rc::new(tmp) + }) + .collect::>(); + assigned.extend(vals); + + let zero_vals = (values.len()..self.config.num_elem_per_packed) + .map(|i| { + zero + .copy_advice(|| "", &mut region, columns[col_offset + i], 0) + .unwrap(); + F::ZERO + }) + .collect::>(); + values.extend(zero_vals); + + let res = + values + .iter() + .zip(self.config.exponents.iter()) + .fold(F::ZERO, |acc, (inp, exp)| { + let res = acc + (*inp + shift_val) * (*exp); + res + }); + + let outp = region.assign_advice( + || "", + columns[col_offset + self.config.num_elem_per_packed], + 0, + || Value::known(res), + )?; + packed.push(Rc::new(outp)); + } + + Ok((packed, assigned)) + }, + )?; + + Ok(outp) + } + + pub fn assign_and_pack( + &self, + mut layouter: impl Layouter, + gadget_config: Rc, + constants: &HashMap>, + tensors: &BTreeMap>, + ) -> Result<(BTreeMap>, Vec>), Error> { + let mut values = 
vec![]; + for (_, tensor) in tensors { + for value in tensor.iter() { + values.push(value); + } + } + + let mut packed = vec![]; + let mut assigned = vec![]; + let zero = constants.get(&0).unwrap().clone(); + + let num_elems_per_row = self.config.num_packed_per_row * self.config.num_elem_per_packed; + for i in 0..(values.len().div_ceil(num_elems_per_row)) { + let row = + values[i * num_elems_per_row..min((i + 1) * num_elems_per_row, values.len())].to_vec(); + let (row_packed, row_assigned) = self + .assign_and_pack_row( + layouter.namespace(|| "pack row"), + gadget_config.clone(), + row, + zero.as_ref(), + ) + .unwrap(); + packed.extend(row_packed); + assigned.extend(row_assigned); + } + + let mut assigned_tensors = BTreeMap::new(); + let mut start_idx = 0; + for (tensor_id, tensor) in tensors { + let num_el = tensor.len(); + let v = assigned[start_idx..start_idx + num_el].to_vec(); + let new_tensor = Array::from_shape_vec(tensor.raw_dim(), v).unwrap(); + assigned_tensors.insert(*tensor_id, new_tensor); + start_idx += num_el; + } + + Ok((assigned_tensors, packed)) + } + + pub fn copy_and_pack( + &self, + mut layouter: impl Layouter, + gadget_config: Rc, + constants: &HashMap>, + tensors: &BTreeMap>, + ) -> Result>, Error> { + let mut values = vec![]; + for (_, tensor) in tensors { + for value in tensor.iter() { + values.push(value.clone()); + } + } + + let mut packed = vec![]; + let zero = constants.get(&0).unwrap().clone(); + + let num_elems_per_row = self.config.num_packed_per_row * self.config.num_elem_per_packed; + for i in 0..(values.len().div_ceil(num_elems_per_row)) { + let row = + values[i * num_elems_per_row..min((i + 1) * num_elems_per_row, values.len())].to_vec(); + let row_packed = self + .copy_and_pack_row( + layouter.namespace(|| "pack row"), + gadget_config.clone(), + row, + zero.as_ref(), + ) + .unwrap(); + packed.extend(row_packed); + } + + Ok(packed) + } +} diff --git a/mnist_zkml/src/commitments/poseidon_commit.rs 
b/mnist_zkml/src/commitments/poseidon_commit.rs new file mode 100644 index 0000000..bfaaf13 --- /dev/null +++ b/mnist_zkml/src/commitments/poseidon_commit.rs @@ -0,0 +1,149 @@ +use std::{collections::HashMap, marker::PhantomData, rc::Rc}; + +use halo2_gadgets::poseidon::{ + primitives::{generate_constants, Absorbing, ConstantLength, Domain, Mds, Spec}, + PaddedWord, PoseidonSpongeInstructions, Pow5Chip, Pow5Config, Sponge, +}; +use halo2_proofs::{ + circuit::Layouter, + halo2curves::ff::{FromUniformBytes, PrimeField}, + plonk::{Advice, Column, ConstraintSystem, Error}, +}; + +use crate::{gadgets::gadget::GadgetConfig, layers::layer::CellRc}; + +use super::commit::Commit; + +pub const WIDTH: usize = 3; +pub const RATE: usize = 2; +pub const L: usize = 8 - WIDTH - 1; + +#[derive(Clone, Debug)] + +pub struct PoseidonCommitChip< + F: PrimeField + Ord + FromUniformBytes<64>, + const WIDTH: usize, + const RATE: usize, + const L: usize, +> { + pub poseidon_config: Pow5Config, +} + +#[derive(Debug)] +pub struct P128Pow5T3Gen(PhantomData); + +impl P128Pow5T3Gen { + pub fn new() -> Self { + P128Pow5T3Gen(PhantomData::default()) + } +} + +impl + Ord, const SECURE_MDS: usize> Spec + for P128Pow5T3Gen +{ + fn full_rounds() -> usize { + 8 + } + + fn partial_rounds() -> usize { + 56 + } + + fn sbox(val: F) -> F { + val.pow_vartime([5]) + } + + fn secure_mds() -> usize { + SECURE_MDS + } + + fn constants() -> (Vec<[F; 3]>, Mds, Mds) { + generate_constants::<_, Self, 3, 2>() + } +} + +/// A Poseidon hash function, built around a sponge. +#[derive(Debug)] +pub struct MyHash< + F: PrimeField, + PoseidonChip: PoseidonSpongeInstructions, + S: Spec, + D: Domain, + const T: usize, + const RATE: usize, +> { + pub sponge: Sponge, RATE>, D, T, RATE>, +} + +impl> PoseidonCommitChip { + pub fn configure( + meta: &mut ConstraintSystem, + // TODO: ?? 
+ _input: [Column; L], + state: [Column; WIDTH], + partial_sbox: Column, + ) -> PoseidonCommitChip { + let rc_a = (0..WIDTH).map(|_| meta.fixed_column()).collect::>(); + let rc_b = (0..WIDTH).map(|_| meta.fixed_column()).collect::>(); + + meta.enable_constant(rc_b[0]); + + PoseidonCommitChip { + poseidon_config: Pow5Chip::configure::>( + meta, + state.try_into().unwrap(), + partial_sbox, + rc_a.try_into().unwrap(), + rc_b.try_into().unwrap(), + ), + } + } +} + +impl> Commit + for PoseidonCommitChip +{ + fn commit( + &self, + mut layouter: impl Layouter, + _gadget_config: Rc, + _constants: &HashMap>, + values: &Vec>, + blinding: CellRc, + ) -> Result>, Error> { + let chip = Pow5Chip::construct(self.poseidon_config.clone()); + let mut hasher: MyHash, P128Pow5T3Gen, ConstantLength, 3, 2> = + Sponge::new(chip, layouter.namespace(|| "sponge")) + .map(|sponge| MyHash { sponge }) + .unwrap(); + + let mut new_vals = values + .iter() + .map(|x| x.clone()) + .chain(vec![blinding.clone()]) + .collect::>(); + while new_vals.len() % L != 0 { + new_vals.push(blinding.clone()); + } + for (i, value) in new_vals + .iter() + .map(|x| PaddedWord::Message((**x).clone())) + .chain( as Domain>::padding(L).map(PaddedWord::Padding)) + .enumerate() + { + hasher + .sponge + .absorb(layouter.namespace(|| format!("absorb {}", i)), value) + .unwrap(); + } + let outp = hasher + .sponge + .finish_absorbing(layouter.namespace(|| "finish absorbing")) + .unwrap() + .squeeze(layouter.namespace(|| "squeeze")) + .unwrap(); + let outp = Rc::new(outp); + + Ok(vec![outp]) + } +} diff --git a/mnist_zkml/src/gadgets.rs b/mnist_zkml/src/gadgets.rs new file mode 100644 index 0000000..19c10ea --- /dev/null +++ b/mnist_zkml/src/gadgets.rs @@ -0,0 +1,20 @@ +pub mod add_pairs; +pub mod adder; +pub mod bias_div_floor_relu6; +pub mod bias_div_round_relu6; +pub mod dot_prod; +pub mod gadget; +pub mod input_lookup; +pub mod max; +pub mod mul_pairs; +pub mod sqrt_big; +pub mod square; +pub mod squared_diff; +pub mod 
sub_pairs; +pub mod update; +pub mod var_div; +pub mod var_div_big; +pub mod var_div_big3; + +// Generics +pub mod nonlinear; diff --git a/mnist_zkml/src/gadgets/add_pairs.rs b/mnist_zkml/src/gadgets/add_pairs.rs new file mode 100644 index 0000000..530e74b --- /dev/null +++ b/mnist_zkml/src/gadgets/add_pairs.rs @@ -0,0 +1,135 @@ +use std::{marker::PhantomData, rc::Rc}; + +use halo2_proofs::{ + circuit::{AssignedCell, Layouter, Region}, + halo2curves::ff::PrimeField, + plonk::{ConstraintSystem, Error}, + poly::Rotation, +}; + +use super::gadget::{Gadget, GadgetConfig, GadgetType}; + +type AddPairsConfig = GadgetConfig; + +pub struct AddPairsChip { + config: Rc, + _marker: PhantomData, +} + +impl AddPairsChip { + pub fn construct(config: Rc) -> Self { + Self { + config, + _marker: PhantomData, + } + } + + pub fn num_cols_per_op() -> usize { + 3 + } + + pub fn configure(meta: &mut ConstraintSystem, gadget_config: GadgetConfig) -> GadgetConfig { + let selector = meta.selector(); + let columns = gadget_config.columns; + + meta.create_gate("add pair", |meta| { + let s = meta.query_selector(selector); + let mut constraints = vec![]; + for i in 0..columns.len() / Self::num_cols_per_op() { + let offset = i * Self::num_cols_per_op(); + let inp1 = meta.query_advice(columns[offset + 0], Rotation::cur()); + let inp2 = meta.query_advice(columns[offset + 1], Rotation::cur()); + let outp = meta.query_advice(columns[offset + 2], Rotation::cur()); + + let res = inp1 + inp2; + constraints.append(&mut vec![s.clone() * (res - outp)]) + } + + constraints + }); + + let mut selectors = gadget_config.selectors; + selectors.insert(GadgetType::AddPairs, vec![selector]); + + GadgetConfig { + columns, + selectors, + ..gadget_config + } + } +} + +impl Gadget for AddPairsChip { + fn name(&self) -> String { + "add pairs chip".to_string() + } + + fn num_cols_per_op(&self) -> usize { + Self::num_cols_per_op() + } + + fn num_inputs_per_row(&self) -> usize { + self.config.columns.len() / 
self.num_cols_per_op() + } + + fn num_outputs_per_row(&self) -> usize { + self.config.columns.len() / self.num_cols_per_op() + } + + fn op_row_region( + &self, + region: &mut Region, + row_offset: usize, + vec_inputs: &Vec>>, + _single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + let inp1 = &vec_inputs[0]; + let inp2 = &vec_inputs[1]; + assert_eq!(inp1.len(), inp2.len()); + + let columns = &self.config.columns; + + if self.config.use_selectors { + let selector = self.config.selectors.get(&GadgetType::AddPairs).unwrap()[0]; + selector.enable(region, row_offset)?; + } + + let mut outps = vec![]; + for i in 0..inp1.len() { + let offset = i * self.num_cols_per_op(); + let inp1 = inp1[i].copy_advice(|| "", region, columns[offset + 0], row_offset)?; + let inp2 = inp2[i].copy_advice(|| "", region, columns[offset + 1], row_offset)?; + let outp = inp1.value().map(|x: &F| x.to_owned()) + inp2.value().map(|x: &F| x.to_owned()); + + let outp = region.assign_advice(|| "", columns[offset + 2], row_offset, || outp)?; + outps.push(outp); + } + Ok(outps) + } + + fn forward( + &self, + mut layouter: impl Layouter, + vec_inputs: &Vec>>, + single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + let zero = &single_inputs[0]; + + let mut inp1 = vec_inputs[0].clone(); + let mut inp2 = vec_inputs[1].clone(); + let initial_len = inp1.len(); + while inp1.len() % self.num_inputs_per_row() != 0 { + inp1.push(zero); + inp2.push(zero); + } + + let vec_inputs = vec![inp1, inp2]; + + let res = self.op_aligned_rows( + layouter.namespace(|| format!("forward row {}", self.name())), + &vec_inputs, + single_inputs, + )?; + Ok(res[0..initial_len].to_vec()) + } +} diff --git a/mnist_zkml/src/gadgets/adder.rs b/mnist_zkml/src/gadgets/adder.rs new file mode 100644 index 0000000..06595d6 --- /dev/null +++ b/mnist_zkml/src/gadgets/adder.rs @@ -0,0 +1,143 @@ +use std::{marker::PhantomData, rc::Rc}; + +use halo2_proofs::{ + circuit::{AssignedCell, Layouter, Region, Value}, + 
halo2curves::ff::PrimeField, + plonk::{ConstraintSystem, Error, Expression}, + poly::Rotation, +}; + +use super::gadget::{Gadget, GadgetConfig, GadgetType}; + +type AdderConfig = GadgetConfig; + +pub struct AdderChip { + config: Rc, + _marker: PhantomData, +} + +impl AdderChip { + pub fn construct(config: Rc) -> Self { + Self { + config, + _marker: PhantomData, + } + } + + pub fn configure(meta: &mut ConstraintSystem, gadget_config: GadgetConfig) -> GadgetConfig { + let selector = meta.selector(); + let columns = gadget_config.columns; + + meta.create_gate("adder gate", |meta| { + let s = meta.query_selector(selector); + let gate_inp = columns[0..columns.len() - 1] + .iter() + .map(|col| meta.query_advice(*col, Rotation::cur())) + .collect::>(); + let gate_output = meta.query_advice(*columns.last().unwrap(), Rotation::cur()); + + let res = gate_inp + .iter() + .fold(Expression::Constant(F::ZERO), |a, b| a + b.clone()); + + vec![s * (res - gate_output)] + }); + + let mut selectors = gadget_config.selectors; + selectors.insert(GadgetType::Adder, vec![selector]); + + GadgetConfig { + columns, + selectors, + ..gadget_config + } + } +} + +// NOTE: The forward pass of the adder adds _everything_ into one cell +impl Gadget for AdderChip { + fn name(&self) -> String { + "adder".to_string() + } + + fn num_cols_per_op(&self) -> usize { + self.config.columns.len() + } + + fn num_inputs_per_row(&self) -> usize { + self.config.columns.len() - 1 + } + + fn num_outputs_per_row(&self) -> usize { + 1 + } + + fn op_row_region( + &self, + region: &mut Region, + row_offset: usize, + vec_inputs: &Vec>>, + _single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + assert_eq!(vec_inputs.len(), 1); + let inp = &vec_inputs[0]; + + if self.config.use_selectors { + let selector = self.config.selectors.get(&GadgetType::Adder).unwrap()[0]; + selector.enable(region, row_offset)?; + } + + inp + .iter() + .enumerate() + .map(|(i, cell)| cell.copy_advice(|| "", region, self.config.columns[i], 
row_offset)) + .collect::, _>>()?; + + let e = inp.iter().fold(Value::known(F::ZERO), |a, b| { + a + b.value().map(|x: &F| x.to_owned()) + }); + let res = region.assign_advice( + || "", + *self.config.columns.last().unwrap(), + row_offset, + || e, + )?; + + Ok(vec![res]) + } + + fn forward( + &self, + mut layouter: impl Layouter, + vec_inputs: &Vec>>, + single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + assert_eq!(single_inputs.len(), 1); + + let mut inputs = vec_inputs[0].clone(); + let zero = single_inputs[0].clone(); + + while inputs.len() % self.num_inputs_per_row() != 0 { + inputs.push(&zero); + } + + let mut outputs = self.op_aligned_rows( + layouter.namespace(|| "adder forward"), + &vec![inputs], + single_inputs, + )?; + while outputs.len() != 1 { + while outputs.len() % self.num_inputs_per_row() != 0 { + outputs.push(zero.clone()); + } + let tmp = outputs.iter().map(|x| x).collect::>(); + outputs = self.op_aligned_rows( + layouter.namespace(|| "adder forward"), + &vec![tmp], + single_inputs, + )?; + } + + Ok(outputs) + } +} diff --git a/mnist_zkml/src/gadgets/bias_div_floor_relu6.rs b/mnist_zkml/src/gadgets/bias_div_floor_relu6.rs new file mode 100644 index 0000000..bccd5eb --- /dev/null +++ b/mnist_zkml/src/gadgets/bias_div_floor_relu6.rs @@ -0,0 +1,274 @@ +use std::{collections::HashMap, marker::PhantomData}; + +use halo2_proofs::{ + circuit::{AssignedCell, Layouter, Region}, + halo2curves::ff::PrimeField, + plonk::{ConstraintSystem, Error, Expression}, + poly::Rotation, +}; + +use crate::gadgets::gadget::convert_to_u64; + +use super::gadget::{Gadget, GadgetConfig, GadgetType}; + +type BiasDivFloorRelu6Config = GadgetConfig; + +const SHIFT_MIN_VAL: i64 = -(1 << 30); + +pub struct BiasDivFloorRelu6Chip { + config: BiasDivFloorRelu6Config, + _marker: PhantomData, +} + +impl BiasDivFloorRelu6Chip { + pub fn construct(config: BiasDivFloorRelu6Config) -> Self { + Self { + config, + _marker: PhantomData, + } + } + + pub fn get_map(scale_factor: u64, 
num_rows: i64, div_outp_min_val: i64) -> HashMap { + let div_val = scale_factor; + let div_outp_min_val = div_outp_min_val; + + let mut map = HashMap::new(); + for i in 0..num_rows { + let shifted = i + div_outp_min_val; + let val = shifted.clamp(0, 6 * div_val as i64); + map.insert(i as i64, val); + } + map + } + + pub fn num_cols_per_op() -> usize { + 5 + } + + pub fn configure(meta: &mut ConstraintSystem, gadget_config: GadgetConfig) -> GadgetConfig { + let selector = meta.complex_selector(); + let sf = Expression::Constant(F::from(gadget_config.scale_factor)); + let columns = gadget_config.columns; + + let mod_lookup = meta.lookup_table_column(); + let relu_lookup = meta.lookup_table_column(); + let div_lookup = meta.lookup_table_column(); + + meta.create_gate("bias_mul", |meta| { + let s = meta.query_selector(selector); + + let mut constraints = vec![]; + for op_idx in 0..columns.len() / Self::num_cols_per_op() { + let offset = op_idx * Self::num_cols_per_op(); + let inp = meta.query_advice(columns[offset + 0], Rotation::cur()); + let bias = meta.query_advice(columns[offset + 1], Rotation::cur()); + let div_res = meta.query_advice(columns[offset + 2], Rotation::cur()); + let mod_res = meta.query_advice(columns[offset + 3], Rotation::cur()); + + constraints.push(s.clone() * (inp - (sf.clone() * (div_res - bias) + mod_res))); + } + + constraints + }); + + for op_idx in 0..columns.len() / Self::num_cols_per_op() { + let offset = op_idx * Self::num_cols_per_op(); + meta.lookup("bias_div_relu6 lookup", |meta| { + let s = meta.query_selector(selector); + let mod_res = meta.query_advice(columns[offset + 3], Rotation::cur()); + + // Constrains that the modulus \in [0, DIV_VAL) + vec![(s.clone() * mod_res.clone(), mod_lookup)] + }); + meta.lookup("bias_div_relu6 lookup", |meta| { + let s = meta.query_selector(selector); + let div = meta.query_advice(columns[offset + 2], Rotation::cur()); + let outp = meta.query_advice(columns[offset + 4], Rotation::cur()); + let 
div_outp_min_val = Expression::Constant(F::from((-SHIFT_MIN_VAL) as u64)); + + // Constrains that output \in [0, 6 * SF] + vec![ + (s.clone() * outp, relu_lookup), + (s * (div + div_outp_min_val), div_lookup), + ] + }); + } + + let mut selectors = gadget_config.selectors; + selectors.insert(GadgetType::BiasDivFloorRelu6, vec![selector]); + + let mut tables = gadget_config.tables; + tables.insert( + GadgetType::BiasDivFloorRelu6, + vec![mod_lookup, relu_lookup, div_lookup], + ); + + let mut maps = gadget_config.maps; + let relu_map = Self::get_map( + gadget_config.scale_factor, + gadget_config.num_rows as i64, + gadget_config.div_outp_min_val, + ); + maps.insert(GadgetType::BiasDivFloorRelu6, vec![relu_map]); + + GadgetConfig { + columns, + selectors, + tables, + maps, + ..gadget_config + } + } +} + +impl Gadget for BiasDivFloorRelu6Chip { + fn name(&self) -> String { + "BiasDivRelu6".to_string() + } + + fn num_cols_per_op(&self) -> usize { + Self::num_cols_per_op() + } + + fn num_inputs_per_row(&self) -> usize { + self.config.columns.len() / self.num_cols_per_op() + } + + fn num_outputs_per_row(&self) -> usize { + self.num_inputs_per_row() + } + + fn op_row_region( + &self, + region: &mut Region, + row_offset: usize, + vec_inputs: &Vec>>, + _single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + let div_val = self.config.scale_factor as i64; + + let div_outp_min_val_i64 = -self.config.div_outp_min_val; + + let div_inp_min_val_pos_i64 = -SHIFT_MIN_VAL; + let div_inp_min_val_pos = F::from(div_inp_min_val_pos_i64 as u64); + + let inp = &vec_inputs[0]; + let bias = &vec_inputs[1]; + assert_eq!(inp.len(), bias.len()); + assert_eq!(inp.len() % self.num_inputs_per_row(), 0); + + let relu_map = &self + .config + .maps + .get(&GadgetType::BiasDivFloorRelu6) + .unwrap()[0]; + + if self.config.use_selectors { + let selector = self + .config + .selectors + .get(&GadgetType::BiasDivFloorRelu6) + .unwrap()[0]; + selector.enable(region, row_offset)?; + } + + let mut 
outp_cells = vec![]; + for (i, (inp, bias)) in inp.iter().zip(bias.iter()).enumerate() { + let offset = i * self.num_cols_per_op(); + + let inp_f = inp.value().map(|x: &F| x.to_owned()); + let bias_f = bias.value().map(|x: &F| { + let a = *x + div_inp_min_val_pos; + let a = convert_to_u64(&a) as i64 - div_inp_min_val_pos_i64; + a + }); + let div_mod_res = inp_f.map(|x: F| { + let x_pos = x + div_inp_min_val_pos; + let inp = convert_to_u64(&x_pos); + // println!("inp: {:?}, bias: {:?}, x_pos: {:?}", inp, bias, x_pos); + let div_res = inp as i64 / div_val - (div_inp_min_val_pos_i64 / div_val); + let mod_res = inp as i64 % div_val; + // println!("div_res: {:?}, mod_res: {:?}", div_res, mod_res); + (div_res, mod_res) + }); + let div_res = div_mod_res.map(|x: (i64, i64)| x.0) + bias_f; + let mod_res = div_mod_res.map(|x: (i64, i64)| x.1); + + let outp = div_res.map(|x: i64| { + let mut x_pos = x - div_outp_min_val_i64; + if !relu_map.contains_key(&(x_pos)) { + println!("x: {}, x_pos: {}", x, x_pos); + x_pos = 0; + } + let outp_val = relu_map.get(&(x_pos)).unwrap(); + // println!("x: {}, x_pos: {}, outp_val: {}", x, x_pos, outp_val); + F::from(*outp_val as u64) + }); + + // Assign inp, bias + inp.copy_advice(|| "", region, self.config.columns[offset + 0], row_offset)?; + bias.copy_advice(|| "", region, self.config.columns[offset + 1], row_offset)?; + + // Assign div_res, mod_res + let div_res_cell = region + .assign_advice( + || "div_res", + self.config.columns[offset + 2], + row_offset, + || { + div_res.map(|x: i64| { + F::from((x - div_outp_min_val_i64) as u64) - F::from(-div_outp_min_val_i64 as u64) + }) + }, + ) + .unwrap(); + let _mod_res_cell = region + .assign_advice( + || "mod_res", + self.config.columns[offset + 3], + row_offset, + || mod_res.map(|x: i64| F::from(x as u64)), + ) + .unwrap(); + + let outp_cell = region + .assign_advice( + || "outp", + self.config.columns[offset + 4], + row_offset, + || outp.map(|x: F| x.to_owned()), + ) + .unwrap(); + + // 
outp_cells.push((outp_cell, div_res_cell)); + outp_cells.push(outp_cell); + outp_cells.push(div_res_cell); + } + + Ok(outp_cells) + } + + fn forward( + &self, + mut layouter: impl Layouter, + vec_inputs: &Vec>>, + single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + let mut inps = vec_inputs[0].clone(); + let mut biases = vec_inputs[1].clone(); + + // Needed to pad: bias - bias = 0 + let default = biases[0].clone(); + while inps.len() % self.num_inputs_per_row() != 0 { + inps.push(&default); + biases.push(&default); + } + + let res = self.op_aligned_rows( + layouter.namespace(|| "bias_div_relu6"), + &vec![inps, biases], + single_inputs, + )?; + Ok(res) + } +} diff --git a/mnist_zkml/src/gadgets/bias_div_round_relu6.rs b/mnist_zkml/src/gadgets/bias_div_round_relu6.rs new file mode 100644 index 0000000..6d8c125 --- /dev/null +++ b/mnist_zkml/src/gadgets/bias_div_round_relu6.rs @@ -0,0 +1,307 @@ +use std::{collections::HashMap, marker::PhantomData, rc::Rc}; + +use halo2_proofs::{ + circuit::{AssignedCell, Layouter, Region, Value}, + halo2curves::ff::PrimeField, + plonk::{ConstraintSystem, Error, Expression}, + poly::Rotation, +}; + +use crate::gadgets::gadget::convert_to_u64; + +use super::gadget::{Gadget, GadgetConfig, GadgetType}; + +type BiasDivRoundRelu6Config = GadgetConfig; + +const NUM_COLS_PER_OP: usize = 5; + +pub struct BiasDivRoundRelu6Chip { + config: Rc, + _marker: PhantomData, +} + +impl BiasDivRoundRelu6Chip { + pub fn construct(config: Rc) -> Self { + Self { + config, + _marker: PhantomData, + } + } + + pub fn get_map(scale_factor: u64, min_val: i64, num_rows: i64) -> HashMap { + let div_val = scale_factor; + + let mut map = HashMap::new(); + for i in 0..num_rows { + let shifted = i + min_val; + let val = shifted.clamp(0, 6 * div_val as i64); + map.insert(i as i64, val); + } + map + } + + pub fn configure(meta: &mut ConstraintSystem, gadget_config: GadgetConfig) -> GadgetConfig { + let selector = meta.complex_selector(); + let sf = 
Expression::Constant(F::from(gadget_config.scale_factor)); + let two = Expression::Constant(F::from(2)); + let columns = gadget_config.columns; + + let mut tables = gadget_config.tables; + let div_lookup = tables.get(&GadgetType::InputLookup).unwrap()[0]; + let relu_lookup = meta.lookup_table_column(); + + meta.create_gate("bias_mul", |meta| { + let s = meta.query_selector(selector); + + let mut constraints = vec![]; + for op_idx in 0..columns.len() / NUM_COLS_PER_OP { + let offset = op_idx * NUM_COLS_PER_OP; + let inp = meta.query_advice(columns[offset + 0], Rotation::cur()); + let bias = meta.query_advice(columns[offset + 1], Rotation::cur()); + let div_res = meta.query_advice(columns[offset + 2], Rotation::cur()); + let mod_res = meta.query_advice(columns[offset + 3], Rotation::cur()); + + // ((div - bias) * 2 + mod) * sf = 2 * inp + sf + constraints.push( + s.clone() + * (two.clone() * inp + sf.clone() + - (sf.clone() * two.clone() * (div_res - bias) + mod_res)), + ); + } + + constraints + }); + + for op_idx in 0..columns.len() / NUM_COLS_PER_OP { + let offset = op_idx * NUM_COLS_PER_OP; + meta.lookup("bias_div_relu6 lookup", |meta| { + let s = meta.query_selector(selector); + let mod_res = meta.query_advice(columns[offset + 3], Rotation::cur()); + + // Constrains that the modulus \in [0, DIV_VAL) + // div_val - mod_res \in [0, max_val) + vec![(s.clone() * (two.clone() * sf.clone() - mod_res), div_lookup)] + }); + meta.lookup("bias_div_relu6 lookup", |meta| { + let s = meta.query_selector(selector); + let div = meta.query_advice(columns[offset + 2], Rotation::cur()); + let outp = meta.query_advice(columns[offset + 4], Rotation::cur()); + let div_outp_min_val = gadget_config.div_outp_min_val; + let div_outp_min_val = Expression::Constant(F::from((-div_outp_min_val) as u64)); + + // Constrains that output \in [0, 6 * SF] + vec![ + (s.clone() * (div + div_outp_min_val), div_lookup), + (s.clone() * outp, relu_lookup), + ] + }); + } + + let mut selectors = 
gadget_config.selectors; + selectors.insert(GadgetType::BiasDivRoundRelu6, vec![selector]); + + tables.insert(GadgetType::BiasDivRoundRelu6, vec![relu_lookup]); + + let mut maps = gadget_config.maps; + let relu_map = Self::get_map( + gadget_config.scale_factor, + gadget_config.min_val, + gadget_config.num_rows as i64, + ); + maps.insert(GadgetType::BiasDivRoundRelu6, vec![relu_map]); + + GadgetConfig { + columns, + selectors, + tables, + maps, + ..gadget_config + } + } +} + +impl Gadget for BiasDivRoundRelu6Chip { + fn name(&self) -> String { + "BiasDivRelu6".to_string() + } + + fn num_cols_per_op(&self) -> usize { + NUM_COLS_PER_OP + } + + fn num_inputs_per_row(&self) -> usize { + self.config.columns.len() / NUM_COLS_PER_OP + } + + fn num_outputs_per_row(&self) -> usize { + self.num_inputs_per_row() * 2 + } + + fn load_lookups(&self, mut layouter: impl Layouter) -> Result<(), Error> { + let map = &self.config.maps[&GadgetType::BiasDivRoundRelu6][0]; + + let relu_lookup = self.config.tables[&GadgetType::BiasDivRoundRelu6][0]; + + layouter + .assign_table( + || "bdr round div/relu lookup", + |mut table| { + for i in 0..self.config.num_rows { + let i = i as i64; + let val = map.get(&i).unwrap(); + table + .assign_cell( + || "relu lookup", + relu_lookup, + i as usize, + || Value::known(F::from(*val as u64)), + ) + .unwrap(); + } + Ok(()) + }, + ) + .unwrap(); + + Ok(()) + } + + fn op_row_region( + &self, + region: &mut Region, + row_offset: usize, + vec_inputs: &Vec>>, + _single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + let div_val = self.config.scale_factor as i64; + + let div_outp_min_val_i64 = self.config.div_outp_min_val; + + let div_inp_min_val_pos_i64 = -self.config.shift_min_val; + let div_inp_min_val_pos = F::from(div_inp_min_val_pos_i64 as u64); + + let inp = &vec_inputs[0]; + let bias = &vec_inputs[1]; + assert_eq!(inp.len(), bias.len()); + assert_eq!(inp.len() % self.num_inputs_per_row(), 0); + + let relu_map = &self + .config + .maps + 
.get(&GadgetType::BiasDivRoundRelu6) + .unwrap()[0]; + + if self.config.use_selectors { + let selector = self + .config + .selectors + .get(&GadgetType::BiasDivRoundRelu6) + .unwrap()[0]; + selector.enable(region, row_offset).unwrap(); + } + + let mut outp_cells = vec![]; + for (i, (inp, bias)) in inp.iter().zip(bias.iter()).enumerate() { + let offset = i * NUM_COLS_PER_OP; + + let inp_f = inp.value().map(|x: &F| x.to_owned()); + let bias_f = bias.value().map(|x: &F| { + let a = *x + div_inp_min_val_pos; + let a = convert_to_u64(&a) as i64 - div_inp_min_val_pos_i64; + a + }); + let div_mod_res = inp_f.map(|x: F| { + let x_pos = x + div_inp_min_val_pos; + let inp = convert_to_u64(&x_pos) as i64; + let div_inp = 2 * inp + div_val; + let div_res = div_inp / (2 * div_val) - div_inp_min_val_pos_i64 / div_val; + let mod_res = div_inp % (2 * div_val); + (div_res, mod_res) + }); + let div_res = div_mod_res.map(|x: (i64, i64)| x.0) + bias_f; + let mod_res = div_mod_res.map(|x: (i64, i64)| x.1); + + let outp = div_res.map(|x: i64| { + let mut x_pos = x - div_outp_min_val_i64; + if !relu_map.contains_key(&(x_pos)) { + println!("x: {}, x_pos: {}", x, x_pos); + x_pos = 0; + } + let outp_val = relu_map.get(&(x_pos)).unwrap(); + F::from(*outp_val as u64) + }); + + // Assign inp, bias + inp + .copy_advice(|| "", region, self.config.columns[offset + 0], row_offset) + .unwrap(); + bias + .copy_advice(|| "", region, self.config.columns[offset + 1], row_offset) + .unwrap(); + + // Assign div_res, mod_res + let div_res_cell = region + .assign_advice( + || "div_res", + self.config.columns[offset + 2], + row_offset, + || { + div_res.map(|x: i64| { + F::from((x - div_outp_min_val_i64) as u64) - F::from(-div_outp_min_val_i64 as u64) + }) + }, + ) + .unwrap(); + let _mod_res_cell = region + .assign_advice( + || "mod_res", + self.config.columns[offset + 3], + row_offset, + || mod_res.map(|x: i64| F::from(x as u64)), + ) + .unwrap(); + + let outp_cell = region + .assign_advice( + || "outp", + 
self.config.columns[offset + 4], + row_offset, + || outp.map(|x: F| x.to_owned()), + ) + .unwrap(); + + // outp_cells.push((outp_cell, div_res_cell)); + outp_cells.push(outp_cell); + outp_cells.push(div_res_cell); + } + + Ok(outp_cells) + } + + fn forward( + &self, + mut layouter: impl Layouter, + vec_inputs: &Vec>>, + single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + let mut inps = vec_inputs[0].clone(); + let mut biases = vec_inputs[1].clone(); + let initial_len = inps.len(); + + // Needed to pad: bias - bias = 0 + let default = biases[0].clone(); + while inps.len() % self.num_inputs_per_row() != 0 { + inps.push(&default); + biases.push(&default); + } + + let res = self + .op_aligned_rows( + layouter.namespace(|| "bias_div_relu6"), + &vec![inps, biases], + single_inputs, + ) + .unwrap(); + Ok(res[0..initial_len * 2].to_vec()) + } +} diff --git a/mnist_zkml/src/gadgets/dot_prod.rs b/mnist_zkml/src/gadgets/dot_prod.rs new file mode 100644 index 0000000..54aea1e --- /dev/null +++ b/mnist_zkml/src/gadgets/dot_prod.rs @@ -0,0 +1,210 @@ +use std::{marker::PhantomData, rc::Rc}; + +use halo2_proofs::{ + circuit::{AssignedCell, Layouter, Region}, + halo2curves::ff::PrimeField, + plonk::{Advice, Column, ConstraintSystem, Error, Expression}, + poly::Rotation, +}; + +use crate::gadgets::adder::AdderChip; + +use super::gadget::{Gadget, GadgetConfig, GadgetType}; + +type DotProductConfig = GadgetConfig; + +pub struct DotProductChip { + config: Rc, + _marker: PhantomData, +} + +impl DotProductChip { + pub fn construct(config: Rc) -> Self { + Self { + config, + _marker: PhantomData, + } + } + + pub fn get_input_columns(config: &GadgetConfig) -> Vec> { + let num_inputs = (config.columns.len() - 1) / 2; + config.columns[0..num_inputs].to_vec() + } + + pub fn get_weight_columns(config: &GadgetConfig) -> Vec> { + let num_inputs = (config.columns.len() - 1) / 2; + config.columns[num_inputs..config.columns.len() - 1].to_vec() + } + + pub fn configure(meta: &mut 
ConstraintSystem, gadget_config: GadgetConfig) -> GadgetConfig { + let selector = meta.selector(); + let columns = &gadget_config.columns; + + meta.create_gate("dot product gate", |meta| { + let s = meta.query_selector(selector); + let gate_inp = DotProductChip::::get_input_columns(&gadget_config) + .iter() + .map(|col| meta.query_advice(*col, Rotation::cur())) + .collect::>(); + let gate_weights = DotProductChip::::get_weight_columns(&gadget_config) + .iter() + .map(|col| meta.query_advice(*col, Rotation::cur())) + .collect::>(); + let gate_output = meta.query_advice(columns[columns.len() - 1], Rotation::cur()); + + let res = gate_inp + .iter() + .zip(gate_weights) + .map(|(a, b)| a.clone() * b.clone()) + .fold(Expression::Constant(F::ZERO), |a, b| a + b); + + vec![s * (res - gate_output)] + }); + + let mut selectors = gadget_config.selectors; + selectors.insert(GadgetType::DotProduct, vec![selector]); + + GadgetConfig { + columns: gadget_config.columns, + selectors, + ..gadget_config + } + } +} + +impl Gadget for DotProductChip { + fn name(&self) -> String { + "dot product".to_string() + } + + fn num_cols_per_op(&self) -> usize { + self.config.columns.len() + } + + fn num_inputs_per_row(&self) -> usize { + (self.config.columns.len() - 1) / 2 + } + + fn num_outputs_per_row(&self) -> usize { + 1 + } + + // The caller is expected to pad the inputs + fn op_row_region( + &self, + region: &mut Region, + row_offset: usize, + vec_inputs: &Vec>>, + single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + assert_eq!(vec_inputs.len(), 2); + + let inp = &vec_inputs[0]; + let weights = &vec_inputs[1]; + assert_eq!(inp.len(), weights.len()); + assert_eq!(inp.len(), self.num_inputs_per_row()); + + let zero = &single_inputs[0]; + + if self.config.use_selectors { + let selector = self.config.selectors.get(&GadgetType::DotProduct).unwrap()[0]; + selector.enable(region, row_offset).unwrap(); + } + + let inp_cols = DotProductChip::::get_input_columns(&self.config); + inp + 
.iter() + .enumerate() + .map(|(i, cell)| cell.copy_advice(|| "", region, inp_cols[i], row_offset)) + .collect::, _>>() + .unwrap(); + + let weight_cols = DotProductChip::::get_weight_columns(&self.config); + weights + .iter() + .enumerate() + .map(|(i, cell)| cell.copy_advice(|| "", region, weight_cols[i], row_offset)) + .collect::, _>>() + .unwrap(); + + // All columns need to be assigned + if self.config.columns.len() % 2 == 0 { + zero + .copy_advice( + || "", + region, + self.config.columns[self.config.columns.len() - 2], + row_offset, + ) + .unwrap(); + } + + let e = inp + .iter() + .zip(weights.iter()) + .map(|(a, b)| a.value().map(|x: &F| *x) * b.value()) + .reduce(|a, b| a + b) + .unwrap(); + + let res = region + .assign_advice( + || "", + self.config.columns[self.config.columns.len() - 1], + row_offset, + || e, + ) + .unwrap(); + + Ok(vec![res]) + } + + fn forward( + &self, + mut layouter: impl Layouter, + vec_inputs: &Vec>>, + single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + assert_eq!(vec_inputs.len(), 2); + assert_eq!(single_inputs.len(), 1); + let zero = &single_inputs[0]; + + let mut inputs = vec_inputs[0].clone(); + let mut weights = vec_inputs[1].clone(); + while inputs.len() % self.num_inputs_per_row() != 0 { + inputs.push(&zero); + weights.push(&zero); + } + + let outputs = layouter + .assign_region( + || "dot prod rows", + |mut region| { + let mut outputs = vec![]; + for i in 0..inputs.len() / self.num_inputs_per_row() { + let inp = + inputs[i * self.num_inputs_per_row()..(i + 1) * self.num_inputs_per_row()].to_vec(); + let weights = + weights[i * self.num_inputs_per_row()..(i + 1) * self.num_inputs_per_row()].to_vec(); + let res = self + .op_row_region(&mut region, i, &vec![inp, weights], &vec![zero.clone()]) + .unwrap(); + outputs.push(res[0].clone()); + } + Ok(outputs) + }, + ) + .unwrap(); + + let adder_chip = AdderChip::::construct(self.config.clone()); + let tmp = outputs.iter().map(|x| x).collect::>(); + Ok( + adder_chip + 
.forward( + layouter.namespace(|| "dot prod adder"), + &vec![tmp], + single_inputs, + ) + .unwrap(), + ) + } +} diff --git a/mnist_zkml/src/gadgets/gadget.rs b/mnist_zkml/src/gadgets/gadget.rs new file mode 100644 index 0000000..182f957 --- /dev/null +++ b/mnist_zkml/src/gadgets/gadget.rs @@ -0,0 +1,153 @@ +use std::{ + collections::{BTreeSet, HashMap}, + sync::Arc, +}; + +use halo2_proofs::{ + circuit::{AssignedCell, Layouter, Region}, + halo2curves::group::ff::PrimeField, + plonk::{Advice, Column, Error, Fixed, Selector, TableColumn}, +}; +use num_bigint::{BigUint, ToBigUint}; +use num_traits::cast::ToPrimitive; + +#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, PartialOrd, Ord)] +pub enum GadgetType { + AddPairs, + Adder, + BiasDivRoundRelu6, + BiasDivFloorRelu6, + DotProduct, + Exp, + Logistic, + Max, + Pow, + Relu, + Rsqrt, + Sqrt, + SqrtBig, + Square, + SquaredDiff, + SubPairs, + Tanh, + MulPairs, + VarDivRound, + VarDivRoundBig, + VarDivRoundBig3, + Packer, // This is a special case + InputLookup, // Dummy placeholder for the input lookup + Update, +} + +#[derive(Clone, Debug, Default)] +pub struct GadgetConfig { + pub used_gadgets: Arc>, + pub columns: Vec>, + pub fixed_columns: Vec>, + pub selectors: HashMap>, + pub tables: HashMap>, + pub maps: HashMap>>, + pub scale_factor: u64, + pub shift_min_val: i64, // MUST be divisible by 2 * scale_factor + pub num_rows: usize, + pub num_cols: usize, + pub k: usize, + pub eta: f64, + pub min_val: i64, + pub max_val: i64, + pub div_outp_min_val: i64, + pub use_selectors: bool, + pub commit_before: Vec>, + pub commit_after: Vec>, + pub num_bits_per_elem: i64, +} + +// TODO: refactor +pub fn convert_to_u64(x: &F) -> u64 { + let big = BigUint::from_bytes_le(x.to_repr().as_ref()); + let big_digits = big.to_u64_digits(); + if big_digits.len() > 2 { + println!("big_digits: {:?}", big_digits); + } + if big_digits.len() == 1 { + big_digits[0] as u64 + } else if big_digits.len() == 0 { + 0 + } else { + panic!(); + } +} + 
+pub fn convert_to_u128(x: &F) -> u128 { + let big = BigUint::from_bytes_le(x.to_repr().as_ref()); + big.to_biguint().unwrap().to_u128().unwrap() +} + +pub trait Gadget { + fn name(&self) -> String; + + fn num_cols_per_op(&self) -> usize; + + fn num_inputs_per_row(&self) -> usize; + + fn num_outputs_per_row(&self) -> usize; + + fn load_lookups(&self, _layouter: impl Layouter) -> Result<(), Error> { + Ok(()) + } + + fn op_row_region( + &self, + region: &mut Region, + row_offset: usize, + vec_inputs: &Vec>>, + single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error>; + + // The caller is required to ensure that the inputs are of the correct length. + fn op_aligned_rows( + &self, + mut layouter: impl Layouter, + vec_inputs: &Vec>>, + single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + // Sanity check inputs + for inp in vec_inputs.iter() { + assert_eq!(inp.len() % self.num_inputs_per_row(), 0); + } + + let outputs = layouter.assign_region( + || format!("gadget {}", self.name()), + |mut region| { + let mut outputs = vec![]; + for i in 0..vec_inputs[0].len() / self.num_inputs_per_row() { + let mut vec_inputs_row = vec![]; + for inp in vec_inputs.iter() { + vec_inputs_row.push( + inp[i * self.num_inputs_per_row()..(i + 1) * self.num_inputs_per_row()].to_vec(), + ); + } + let row_outputs = self.op_row_region(&mut region, i, &vec_inputs_row, &single_inputs)?; + assert_eq!(row_outputs.len(), self.num_outputs_per_row()); + outputs.extend(row_outputs); + } + Ok(outputs) + }, + )?; + + Ok(outputs) + } + + fn forward( + &self, + mut layouter: impl Layouter, + vec_inputs: &Vec>>, + single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + self.op_aligned_rows( + layouter.namespace(|| format!("forward row {}", self.name())), + vec_inputs, + single_inputs, + ) + } +} diff --git a/mnist_zkml/src/gadgets/input_lookup.rs b/mnist_zkml/src/gadgets/input_lookup.rs new file mode 100644 index 0000000..19b5444 --- /dev/null +++ b/mnist_zkml/src/gadgets/input_lookup.rs @@ 
-0,0 +1,87 @@ +use std::{marker::PhantomData, rc::Rc}; + +use halo2_proofs::{ + circuit::{AssignedCell, Layouter, Region, Value}, + halo2curves::ff::PrimeField, + plonk::{ConstraintSystem, Error}, +}; + +use super::gadget::{Gadget, GadgetConfig, GadgetType}; + +pub struct InputLookupChip { + config: Rc, + _marker: PhantomData, +} + +impl InputLookupChip { + pub fn construct(config: Rc) -> Self { + Self { + config, + _marker: PhantomData, + } + } + + pub fn configure(meta: &mut ConstraintSystem, gadget_config: GadgetConfig) -> GadgetConfig { + let lookup = meta.lookup_table_column(); + let mut tables = gadget_config.tables; + tables.insert(GadgetType::InputLookup, vec![lookup]); + + GadgetConfig { + tables, + ..gadget_config + } + } +} + +impl Gadget for InputLookupChip { + fn load_lookups(&self, mut layouter: impl Layouter) -> Result<(), Error> { + let lookup = self.config.tables[&GadgetType::InputLookup][0]; + + layouter + .assign_table( + || "input lookup", + |mut table| { + for i in 0..self.config.num_rows as i64 { + table + .assign_cell( + || "mod lookup", + lookup, + i as usize, + || Value::known(F::from(i as u64)), + ) + .unwrap(); + } + Ok(()) + }, + ) + .unwrap(); + + Ok(()) + } + + fn name(&self) -> String { + panic!("InputLookupChip should not be called directly") + } + + fn num_cols_per_op(&self) -> usize { + panic!("InputLookupChip should not be called directly") + } + + fn num_inputs_per_row(&self) -> usize { + panic!("InputLookupChip should not be called directly") + } + + fn num_outputs_per_row(&self) -> usize { + panic!("InputLookupChip should not be called directly") + } + + fn op_row_region( + &self, + _region: &mut Region, + _row_offset: usize, + _vec_inputs: &Vec>>, + _single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + panic!("InputLookupChip should not be called directly") + } +} diff --git a/mnist_zkml/src/gadgets/max.rs b/mnist_zkml/src/gadgets/max.rs new file mode 100644 index 0000000..48a62d0 --- /dev/null +++ 
b/mnist_zkml/src/gadgets/max.rs @@ -0,0 +1,185 @@ +use std::{marker::PhantomData, rc::Rc}; + +use halo2_proofs::{ + circuit::{AssignedCell, Layouter, Region}, + halo2curves::ff::PrimeField, + plonk::{ConstraintSystem, Error}, + poly::Rotation, +}; + +use crate::gadgets::gadget::convert_to_u64; + +use super::gadget::{Gadget, GadgetConfig, GadgetType}; + +pub struct MaxChip { + config: Rc, + _marker: PhantomData, +} + +impl MaxChip { + pub fn construct(config: Rc) -> Self { + Self { + config, + _marker: PhantomData, + } + } + + pub fn num_cols_per_op() -> usize { + 3 + } + + pub fn configure(meta: &mut ConstraintSystem, gadget_config: GadgetConfig) -> GadgetConfig { + let selector = meta.complex_selector(); + let columns = gadget_config.columns; + let tables = gadget_config.tables; + + let inp_lookup = tables.get(&GadgetType::InputLookup).unwrap()[0]; + + meta.create_gate("max arithmetic", |meta| { + let s = meta.query_selector(selector); + let mut constraints = vec![]; + for i in 0..columns.len() / Self::num_cols_per_op() { + let offset = i * Self::num_cols_per_op(); + let inp1 = meta.query_advice(columns[offset + 0], Rotation::cur()); + let inp2 = meta.query_advice(columns[offset + 1], Rotation::cur()); + let outp = meta.query_advice(columns[offset + 2], Rotation::cur()); + + constraints.push(s.clone() * (inp1 - outp.clone()) * (inp2 - outp)) + } + constraints + }); + + for idx in 0..columns.len() / Self::num_cols_per_op() { + meta.lookup("max inp1", |meta| { + let s = meta.query_selector(selector); + let offset = idx * Self::num_cols_per_op(); + let inp1 = meta.query_advice(columns[offset + 0], Rotation::cur()); + let outp = meta.query_advice(columns[offset + 2], Rotation::cur()); + + vec![(s * (outp - inp1), inp_lookup)] + }); + meta.lookup("max inp2", |meta| { + let s = meta.query_selector(selector); + let offset = idx * Self::num_cols_per_op(); + let inp2 = meta.query_advice(columns[offset + 1], Rotation::cur()); + let outp = meta.query_advice(columns[offset + 
2], Rotation::cur()); + + vec![(s * (outp - inp2), inp_lookup)] + }); + } + + let mut selectors = gadget_config.selectors; + selectors.insert(GadgetType::Max, vec![selector]); + + GadgetConfig { + columns, + selectors, + tables, + ..gadget_config + } + } +} + +impl Gadget for MaxChip { + fn name(&self) -> String { + "max".to_string() + } + + fn num_cols_per_op(&self) -> usize { + 3 + } + + fn num_inputs_per_row(&self) -> usize { + self.config.columns.len() / self.num_cols_per_op() * 2 + } + + fn num_outputs_per_row(&self) -> usize { + self.config.columns.len() / self.num_cols_per_op() + } + + fn op_row_region( + &self, + region: &mut Region, + row_offset: usize, + vec_inputs: &Vec>>, + _single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + assert_eq!(vec_inputs.len(), 1); + let inp = &vec_inputs[0]; + + if self.config.use_selectors { + let selector = self.config.selectors.get(&GadgetType::Max).unwrap()[0]; + selector.enable(region, row_offset)?; + } + + let min_val_pos = F::from((-self.config.shift_min_val) as u64); + + let mut outp = vec![]; + + let chunks: Vec<&[&AssignedCell]> = inp.chunks(self.num_outputs_per_row()).collect(); + let i1 = chunks[0]; + let i2 = chunks[1]; + for (idx, (inp1, inp2)) in i1.iter().zip(i2.iter()).enumerate() { + let offset = idx * self.num_cols_per_op(); + inp1 + .copy_advice(|| "", region, self.config.columns[offset + 0], row_offset) + .unwrap(); + inp2 + .copy_advice(|| "", region, self.config.columns[offset + 1], row_offset) + .unwrap(); + + let max = inp1.value().zip(inp2.value()).map(|(a, b)| { + let a = convert_to_u64(&(*a + min_val_pos)); + let b = convert_to_u64(&(*b + min_val_pos)); + let max = a.max(b); + let max = F::from(max) - min_val_pos; + max + }); + + let res = region + .assign_advice(|| "", self.config.columns[offset + 2], row_offset, || max) + .unwrap(); + outp.push(res); + } + + Ok(outp) + } + + fn forward( + &self, + mut layouter: impl Layouter, + vec_inputs: &Vec>>, + single_inputs: &Vec<&AssignedCell>, 
+ ) -> Result>, Error> { + let mut inputs = vec_inputs[0].clone(); + let first = inputs[0]; + + while inputs.len() % self.num_inputs_per_row() != 0 { + inputs.push(first); + } + + // TODO: pretty sure this is correct but check + let num_iters = inputs.len().div_ceil(self.num_inputs_per_row()) + self.num_inputs_per_row(); + + let mut outputs = self.op_aligned_rows( + layouter.namespace(|| "max forward"), + &vec![inputs], + single_inputs, + )?; + for _ in 0..num_iters { + while outputs.len() % self.num_inputs_per_row() != 0 { + outputs.push(first.clone()); + } + let tmp = outputs.iter().map(|x| x).collect::>(); + outputs = self.op_aligned_rows( + layouter.namespace(|| "max forward"), + &vec![tmp], + single_inputs, + )?; + } + + outputs = vec![outputs.into_iter().next().unwrap()]; + + Ok(outputs) + } +} diff --git a/mnist_zkml/src/gadgets/mul_pairs.rs b/mnist_zkml/src/gadgets/mul_pairs.rs new file mode 100644 index 0000000..12e4c64 --- /dev/null +++ b/mnist_zkml/src/gadgets/mul_pairs.rs @@ -0,0 +1,136 @@ +use std::{marker::PhantomData, rc::Rc}; + +use halo2_proofs::{ + circuit::{AssignedCell, Layouter, Region}, + halo2curves::ff::PrimeField, + plonk::{ConstraintSystem, Error}, + poly::Rotation, +}; + +use super::gadget::{Gadget, GadgetConfig, GadgetType}; + +type MulPairsConfig = GadgetConfig; + +pub struct MulPairsChip { + config: Rc, + _marker: PhantomData, +} + +impl MulPairsChip { + pub fn construct(config: Rc) -> Self { + Self { + config, + _marker: PhantomData, + } + } + + pub fn num_cols_per_op() -> usize { + 3 + } + + pub fn configure(meta: &mut ConstraintSystem, gadget_config: GadgetConfig) -> GadgetConfig { + let selector = meta.selector(); + let columns = gadget_config.columns; + + meta.create_gate("mul pair", |meta| { + let s = meta.query_selector(selector); + let mut constraints = vec![]; + for i in 0..columns.len() / Self::num_cols_per_op() { + let offset = i * Self::num_cols_per_op(); + let inp1 = meta.query_advice(columns[offset + 0], Rotation::cur()); 
+ let inp2 = meta.query_advice(columns[offset + 1], Rotation::cur()); + let outp = meta.query_advice(columns[offset + 2], Rotation::cur()); + + let res = inp1 * inp2; + constraints.append(&mut vec![s.clone() * (res - outp)]) + } + + constraints + }); + + let mut selectors = gadget_config.selectors; + selectors.insert(GadgetType::MulPairs, vec![selector]); + + GadgetConfig { + columns, + selectors, + ..gadget_config + } + } +} + +impl Gadget for MulPairsChip { + fn name(&self) -> String { + "MulPairs".to_string() + } + + fn num_cols_per_op(&self) -> usize { + Self::num_cols_per_op() + } + + fn num_inputs_per_row(&self) -> usize { + self.config.columns.len() / self.num_cols_per_op() + } + + fn num_outputs_per_row(&self) -> usize { + self.config.columns.len() / self.num_cols_per_op() + } + + // TODO: This + below is basically copied from add pairs - make arithmetic generic + fn op_row_region( + &self, + region: &mut Region, + row_offset: usize, + vec_inputs: &Vec>>, + _single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + let inp1 = &vec_inputs[0]; + let inp2 = &vec_inputs[1]; + assert_eq!(inp1.len(), inp2.len()); + + let columns = &self.config.columns; + + if self.config.use_selectors { + let selector = self.config.selectors.get(&GadgetType::MulPairs).unwrap()[0]; + selector.enable(region, row_offset)?; + } + + let mut outps = vec![]; + for i in 0..inp1.len() { + let offset = i * self.num_cols_per_op(); + let inp1 = inp1[i].copy_advice(|| "", region, columns[offset + 0], row_offset)?; + let inp2 = inp2[i].copy_advice(|| "", region, columns[offset + 1], row_offset)?; + let outp = inp1.value().map(|x: &F| x.to_owned()) * inp2.value().map(|x: &F| x.to_owned()); + + let outp = region.assign_advice(|| "", columns[offset + 2], row_offset, || outp)?; + outps.push(outp); + } + Ok(outps) + } + + fn forward( + &self, + mut layouter: impl Layouter, + vec_inputs: &Vec>>, + single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + let zero = &single_inputs[0]; + + 
let mut inp1 = vec_inputs[0].clone(); + let mut inp2 = vec_inputs[1].clone(); + let initial_len = inp1.len(); + while inp1.len() % self.num_inputs_per_row() != 0 { + inp1.push(zero); + inp2.push(zero); + } + + let vec_inputs = vec![inp1, inp2]; + + let res = self.op_aligned_rows( + layouter.namespace(|| format!("forward row {}", self.name())), + &vec_inputs, + single_inputs, + )?; + Ok(res[0..initial_len].to_vec()) + } +} diff --git a/mnist_zkml/src/gadgets/nonlinear.rs b/mnist_zkml/src/gadgets/nonlinear.rs new file mode 100644 index 0000000..26951fb --- /dev/null +++ b/mnist_zkml/src/gadgets/nonlinear.rs @@ -0,0 +1,8 @@ +pub mod exp; +pub mod logistic; +pub mod non_linearity; +pub mod pow; +pub mod relu; +pub mod rsqrt; +pub mod sqrt; +pub mod tanh; diff --git a/mnist_zkml/src/gadgets/nonlinear/exp.rs b/mnist_zkml/src/gadgets/nonlinear/exp.rs new file mode 100644 index 0000000..31c4c65 --- /dev/null +++ b/mnist_zkml/src/gadgets/nonlinear/exp.rs @@ -0,0 +1,104 @@ +use std::{collections::HashMap, marker::PhantomData, rc::Rc}; + +use halo2_proofs::{ + circuit::{AssignedCell, Layouter, Region}, + halo2curves::ff::PrimeField, + plonk::{ConstraintSystem, Error}, +}; + +use super::{ + super::gadget::{Gadget, GadgetConfig, GadgetType}, + non_linearity::NonLinearGadget, +}; + +type ExpGadgetConfig = GadgetConfig; + +// IMPORTANT: this return exp(x) * SF +pub struct ExpGadgetChip { + config: Rc, + _marker: PhantomData, +} + +impl ExpGadgetChip { + pub fn construct(config: Rc) -> Self { + Self { + config, + _marker: PhantomData, + } + } + + pub fn configure(meta: &mut ConstraintSystem, gadget_config: GadgetConfig) -> GadgetConfig { + as NonLinearGadget>::configure(meta, gadget_config, GadgetType::Exp) + } +} + +impl NonLinearGadget for ExpGadgetChip { + fn generate_map(scale_factor: u64, min_val: i64, num_rows: i64) -> HashMap { + let mut map = HashMap::new(); + for i in 0..num_rows { + let shifted = i + min_val; + let x = (shifted as f64) / (scale_factor as f64); + let exp 
= x.exp(); + let exp = (exp * ((scale_factor * scale_factor) as f64)).round() as i64; + map.insert(i as i64, exp); + } + map + } + + fn get_map(&self) -> &HashMap { + &self.config.maps.get(&GadgetType::Exp).unwrap()[0] + } + + fn get_selector(&self) -> halo2_proofs::plonk::Selector { + self.config.selectors.get(&GadgetType::Exp).unwrap()[0] + } +} + +impl Gadget for ExpGadgetChip { + fn name(&self) -> String { + "Exp".to_string() + } + + fn num_cols_per_op(&self) -> usize { + as NonLinearGadget>::num_cols_per_op() + } + + fn num_inputs_per_row(&self) -> usize { + self.config.columns.len() / self.num_cols_per_op() + } + + fn num_outputs_per_row(&self) -> usize { + self.config.columns.len() / self.num_cols_per_op() + } + + fn load_lookups(&self, layouter: impl Layouter) -> Result<(), Error> { + NonLinearGadget::load_lookups(self, layouter, self.config.clone(), GadgetType::Exp)?; + Ok(()) + } + + fn op_row_region( + &self, + region: &mut Region, + row_offset: usize, + vec_inputs: &Vec>>, + single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + NonLinearGadget::op_row_region( + self, + region, + row_offset, + vec_inputs, + single_inputs, + self.config.clone(), + ) + } + + fn forward( + &self, + layouter: impl halo2_proofs::circuit::Layouter, + vec_inputs: &Vec>>, + single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + NonLinearGadget::forward(self, layouter, vec_inputs, single_inputs) + } +} diff --git a/mnist_zkml/src/gadgets/nonlinear/logistic.rs b/mnist_zkml/src/gadgets/nonlinear/logistic.rs new file mode 100644 index 0000000..ed97f0e --- /dev/null +++ b/mnist_zkml/src/gadgets/nonlinear/logistic.rs @@ -0,0 +1,106 @@ +use std::{collections::HashMap, marker::PhantomData, rc::Rc}; + +use halo2_proofs::{ + circuit::{AssignedCell, Layouter, Region}, + halo2curves::ff::PrimeField, + plonk::{ConstraintSystem, Error}, +}; + +use super::{ + super::gadget::{Gadget, GadgetConfig, GadgetType}, + non_linearity::NonLinearGadget, +}; + +pub struct 
LogisticGadgetChip { + config: Rc, + _marker: PhantomData, +} + +impl LogisticGadgetChip { + pub fn construct(config: Rc) -> Self { + Self { + config, + _marker: PhantomData, + } + } + + pub fn configure(meta: &mut ConstraintSystem, gadget_config: GadgetConfig) -> GadgetConfig { + as NonLinearGadget>::configure( + meta, + gadget_config, + GadgetType::Logistic, + ) + } +} + +impl NonLinearGadget for LogisticGadgetChip { + fn generate_map(scale_factor: u64, min_val: i64, num_rows: i64) -> HashMap { + let mut map = HashMap::new(); + for i in 0..num_rows { + let shifted = i + min_val; + let x = (shifted as f64) / (scale_factor as f64); + let logistic = 1. / (1. + (-x).exp()); + let logistic = (logistic * ((scale_factor) as f64)).round() as i64; + map.insert(i as i64, logistic); + } + + map + } + + fn get_map(&self) -> &HashMap { + &self.config.maps.get(&GadgetType::Logistic).unwrap()[0] + } + + fn get_selector(&self) -> halo2_proofs::plonk::Selector { + self.config.selectors.get(&GadgetType::Logistic).unwrap()[0] + } +} + +impl Gadget for LogisticGadgetChip { + fn name(&self) -> String { + "LogisticChip".to_string() + } + + fn num_cols_per_op(&self) -> usize { + as NonLinearGadget>::num_cols_per_op() + } + + fn num_inputs_per_row(&self) -> usize { + self.config.columns.len() / self.num_cols_per_op() + } + + fn num_outputs_per_row(&self) -> usize { + self.config.columns.len() / self.num_cols_per_op() + } + + fn load_lookups(&self, layouter: impl Layouter) -> Result<(), Error> { + NonLinearGadget::load_lookups(self, layouter, self.config.clone(), GadgetType::Logistic)?; + Ok(()) + } + + fn op_row_region( + &self, + region: &mut Region, + row_offset: usize, + vec_inputs: &Vec>>, + single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + NonLinearGadget::op_row_region( + self, + region, + row_offset, + vec_inputs, + single_inputs, + self.config.clone(), + ) + } + + fn forward( + &self, + layouter: impl halo2_proofs::circuit::Layouter, + vec_inputs: &Vec>>, + 
single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + NonLinearGadget::forward(self, layouter, vec_inputs, single_inputs) + } +} diff --git a/mnist_zkml/src/gadgets/nonlinear/non_linearity.rs b/mnist_zkml/src/gadgets/nonlinear/non_linearity.rs new file mode 100644 index 0000000..f7a9eff --- /dev/null +++ b/mnist_zkml/src/gadgets/nonlinear/non_linearity.rs @@ -0,0 +1,190 @@ +use std::{collections::HashMap, rc::Rc}; + +use halo2_proofs::{ + circuit::{AssignedCell, Layouter, Region, Value}, + halo2curves::ff::PrimeField, + plonk::{ConstraintSystem, Error, Expression, Selector}, + poly::Rotation, +}; + +use crate::gadgets::gadget::convert_to_u128; + +use super::super::gadget::Gadget; +use super::super::gadget::{GadgetConfig, GadgetType}; + +const NUM_COLS_PER_OP: usize = 2; + +pub trait NonLinearGadget: Gadget { + fn generate_map(scale_factor: u64, min_val: i64, num_rows: i64) -> HashMap; + + fn get_map(&self) -> &HashMap; + + fn get_selector(&self) -> Selector; + + fn num_cols_per_op() -> usize { + NUM_COLS_PER_OP + } + + fn configure( + meta: &mut ConstraintSystem, + gadget_config: GadgetConfig, + gadget_type: GadgetType, + ) -> GadgetConfig { + let selector = meta.complex_selector(); + let columns = gadget_config.columns; + + let mut tables = gadget_config.tables; + let inp_lookup = tables.get(&GadgetType::InputLookup).unwrap()[0]; + let outp_lookup = meta.lookup_table_column(); + + for op_idx in 0..columns.len() / NUM_COLS_PER_OP { + let offset = op_idx * NUM_COLS_PER_OP; + meta.lookup("non-linear lookup", |meta| { + let s = meta.query_selector(selector); + let inp = meta.query_advice(columns[offset + 0], Rotation::cur()); + let outp = meta.query_advice(columns[offset + 1], Rotation::cur()); + let shift_val = gadget_config.min_val; + let shift_val_pos = Expression::Constant(F::from((-shift_val) as u64)); + + vec![ + (s.clone() * (inp + shift_val_pos), inp_lookup), + (s.clone() * outp, outp_lookup), + ] + }); + } + + let mut selectors = 
gadget_config.selectors; + selectors.insert(gadget_type, vec![selector]); + + tables.insert(gadget_type, vec![inp_lookup, outp_lookup]); + + let mut maps = gadget_config.maps; + let non_linear_map = Self::generate_map( + gadget_config.scale_factor, + gadget_config.min_val, + gadget_config.num_rows as i64, + ); + maps.insert(gadget_type, vec![non_linear_map]); + + GadgetConfig { + columns, + selectors, + tables, + maps, + ..gadget_config + } + } + + fn load_lookups( + &self, + mut layouter: impl Layouter, + config: Rc, + gadget_type: GadgetType, + ) -> Result<(), Error> { + let map = self.get_map(); + let table_col = config.tables.get(&gadget_type).unwrap()[1]; + + let shift_pos_i64 = -config.shift_min_val; + let shift_pos = F::from(shift_pos_i64 as u64); + layouter.assign_table( + || "non linear table", + |mut table| { + for i in 0..config.num_rows { + let i = i as i64; + // FIXME: refactor this + let tmp = *map.get(&i).unwrap(); + let val = if i == 0 { + F::ZERO + } else { + if tmp >= 0 { + F::from(tmp as u64) + } else { + let tmp = tmp + shift_pos_i64; + F::from(tmp as u64) - shift_pos + } + }; + table.assign_cell( + || "non linear cell", + table_col, + i as usize, + || Value::known(val), + )?; + } + Ok(()) + }, + )?; + Ok(()) + } + + fn op_row_region( + &self, + region: &mut Region, + row_offset: usize, + vec_inputs: &Vec>>, + _single_inputs: &Vec<&AssignedCell>, + gadget_config: Rc, + ) -> Result>, Error> { + let columns = &gadget_config.columns; + let inp = &vec_inputs[0]; + let map = self.get_map(); + let shift_val_pos_i64 = -gadget_config.shift_min_val; + let shift_val_pos = F::from(shift_val_pos_i64 as u64); + let min_val = gadget_config.min_val; + + if gadget_config.use_selectors { + let selector = self.get_selector(); + selector.enable(region, row_offset)?; + } + + let mut outps = vec![]; + for i in 0..inp.len() { + let offset = i * 2; + inp[i].copy_advice(|| "", region, columns[offset + 0], row_offset)?; + let outp = inp[i].value().map(|x: &F| { + let 
pos = convert_to_u128(&(*x + shift_val_pos)) as i128 - shift_val_pos_i64 as i128; + let x = pos as i64 - min_val; + let val = *map.get(&x).unwrap(); + if x == 0 { + F::ZERO + } else { + if val >= 0 { + F::from(val as u64) + } else { + let val_pos = val + shift_val_pos_i64; + F::from(val_pos as u64) - F::from(shift_val_pos_i64 as u64) + } + } + }); + + let outp = + region.assign_advice(|| "nonlinearity", columns[offset + 1], row_offset, || outp)?; + outps.push(outp); + } + + Ok(outps) + } + + fn forward( + &self, + mut layouter: impl Layouter, + vec_inputs: &Vec>>, + single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + let zero = &single_inputs[0]; + let inp_len = vec_inputs[0].len(); + let mut inp = vec_inputs[0].clone(); + + while inp.len() % self.num_inputs_per_row() != 0 { + inp.push(zero); + } + + let vec_inputs = vec![inp]; + let outp = self.op_aligned_rows( + layouter.namespace(|| format!("forward row {}", self.name())), + &vec_inputs, + &single_inputs, + )?; + + Ok(outp[0..inp_len].to_vec()) + } +} diff --git a/mnist_zkml/src/gadgets/nonlinear/pow.rs b/mnist_zkml/src/gadgets/nonlinear/pow.rs new file mode 100644 index 0000000..c628ee6 --- /dev/null +++ b/mnist_zkml/src/gadgets/nonlinear/pow.rs @@ -0,0 +1,105 @@ +use std::{collections::HashMap, marker::PhantomData, rc::Rc}; + +use halo2_proofs::{ + circuit::{AssignedCell, Layouter, Region}, + halo2curves::ff::PrimeField, + plonk::{ConstraintSystem, Error}, +}; + +use super::{ + super::gadget::{Gadget, GadgetConfig, GadgetType}, + non_linearity::NonLinearGadget, +}; + +// IMPORTANT: PowGadget assumes a single power across the entire DAG +pub struct PowGadgetChip { + config: Rc, + _marker: PhantomData, +} + +impl PowGadgetChip { + pub fn construct(config: Rc) -> Self { + Self { + config, + _marker: PhantomData, + } + } + + pub fn configure(meta: &mut ConstraintSystem, gadget_config: GadgetConfig) -> GadgetConfig { + as NonLinearGadget>::configure(meta, gadget_config, GadgetType::Pow) + } +} + +impl 
NonLinearGadget for PowGadgetChip { + fn generate_map(scale_factor: u64, min_val: i64, num_rows: i64) -> HashMap { + let power = 3.; // FIXME: need to make this variable somehow... + + let mut map = HashMap::new(); + for i in 0..num_rows { + let shifted = i + min_val; + let x = (shifted as f64) / (scale_factor as f64); + let y = x.powf(power); + let y = (y * ((scale_factor) as f64)).round() as i64; + map.insert(i as i64, y); + } + + map + } + + fn get_map(&self) -> &HashMap { + &self.config.maps.get(&GadgetType::Pow).unwrap()[0] + } + + fn get_selector(&self) -> halo2_proofs::plonk::Selector { + self.config.selectors.get(&GadgetType::Pow).unwrap()[0] + } +} + +impl Gadget for PowGadgetChip { + fn name(&self) -> String { + "PowGadgetChip".to_string() + } + + fn num_cols_per_op(&self) -> usize { + as NonLinearGadget>::num_cols_per_op() + } + + fn num_inputs_per_row(&self) -> usize { + self.config.columns.len() / self.num_cols_per_op() + } + + fn num_outputs_per_row(&self) -> usize { + self.config.columns.len() / self.num_cols_per_op() + } + + fn load_lookups(&self, layouter: impl Layouter) -> Result<(), Error> { + NonLinearGadget::load_lookups(self, layouter, self.config.clone(), GadgetType::Pow)?; + Ok(()) + } + + fn op_row_region( + &self, + region: &mut Region, + row_offset: usize, + vec_inputs: &Vec>>, + single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + NonLinearGadget::op_row_region( + self, + region, + row_offset, + vec_inputs, + single_inputs, + self.config.clone(), + ) + } + + fn forward( + &self, + layouter: impl halo2_proofs::circuit::Layouter, + vec_inputs: &Vec>>, + single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + NonLinearGadget::forward(self, layouter, vec_inputs, single_inputs) + } +} diff --git a/mnist_zkml/src/gadgets/nonlinear/relu.rs b/mnist_zkml/src/gadgets/nonlinear/relu.rs new file mode 100644 index 0000000..d54ca8b --- /dev/null +++ b/mnist_zkml/src/gadgets/nonlinear/relu.rs @@ -0,0 +1,100 @@ +use 
std::{collections::HashMap, marker::PhantomData, rc::Rc}; + +use halo2_proofs::{ + circuit::{AssignedCell, Layouter, Region}, + halo2curves::ff::PrimeField, + plonk::{ConstraintSystem, Error}, +}; + +use super::{ + super::gadget::{Gadget, GadgetConfig, GadgetType}, + non_linearity::NonLinearGadget, +}; + +pub struct ReluChip { + config: Rc, + _marker: PhantomData, +} + +impl ReluChip { + pub fn construct(config: Rc) -> Self { + Self { + config, + _marker: PhantomData, + } + } + + pub fn configure(meta: &mut ConstraintSystem, gadget_config: GadgetConfig) -> GadgetConfig { + as NonLinearGadget>::configure(meta, gadget_config, GadgetType::Relu) + } +} + +impl NonLinearGadget for ReluChip { + fn generate_map(_scale_factor: u64, min_val: i64, num_rows: i64) -> HashMap { + let mut map = HashMap::new(); + for i in 0..num_rows { + let shifted = i + min_val; + let relu = shifted.max(0); + map.insert(i as i64, relu); + } + + map + } + + fn get_map(&self) -> &HashMap { + &self.config.maps.get(&GadgetType::Relu).unwrap()[0] + } + + fn get_selector(&self) -> halo2_proofs::plonk::Selector { + self.config.selectors.get(&GadgetType::Relu).unwrap()[0] + } +} + +impl Gadget for ReluChip { + fn name(&self) -> String { + "Relu".to_string() + } + + fn num_cols_per_op(&self) -> usize { + as NonLinearGadget>::num_cols_per_op() + } + + fn num_inputs_per_row(&self) -> usize { + self.config.columns.len() / self.num_cols_per_op() + } + + fn num_outputs_per_row(&self) -> usize { + self.config.columns.len() / self.num_cols_per_op() + } + + fn load_lookups(&self, layouter: impl Layouter) -> Result<(), Error> { + NonLinearGadget::load_lookups(self, layouter, self.config.clone(), GadgetType::Relu)?; + Ok(()) + } + + fn op_row_region( + &self, + region: &mut Region, + row_offset: usize, + vec_inputs: &Vec>>, + single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + NonLinearGadget::op_row_region( + self, + region, + row_offset, + vec_inputs, + single_inputs, + self.config.clone(), + ) + } + 
+ fn forward( + &self, + layouter: impl halo2_proofs::circuit::Layouter, + vec_inputs: &Vec>>, + single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + NonLinearGadget::forward(self, layouter, vec_inputs, single_inputs) + } +} diff --git a/mnist_zkml/src/gadgets/nonlinear/rsqrt.rs b/mnist_zkml/src/gadgets/nonlinear/rsqrt.rs new file mode 100644 index 0000000..eaea6ba --- /dev/null +++ b/mnist_zkml/src/gadgets/nonlinear/rsqrt.rs @@ -0,0 +1,102 @@ +use std::{collections::HashMap, marker::PhantomData, rc::Rc}; + +use halo2_proofs::{ + circuit::{AssignedCell, Layouter, Region}, + halo2curves::ff::PrimeField, + plonk::{ConstraintSystem, Error}, +}; + +use super::{ + super::gadget::{Gadget, GadgetConfig, GadgetType}, + non_linearity::NonLinearGadget, +}; + +pub struct RsqrtGadgetChip { + config: Rc, + _marker: PhantomData, +} + +impl RsqrtGadgetChip { + pub fn construct(config: Rc) -> Self { + Self { + config, + _marker: PhantomData, + } + } + + pub fn configure(meta: &mut ConstraintSystem, gadget_config: GadgetConfig) -> GadgetConfig { + as NonLinearGadget>::configure(meta, gadget_config, GadgetType::Rsqrt) + } +} + +impl NonLinearGadget for RsqrtGadgetChip { + fn generate_map(scale_factor: u64, min_val: i64, num_rows: i64) -> HashMap { + let mut map = HashMap::new(); + for i in 0..num_rows { + let shifted = i + min_val; + let x = (shifted as f64) / (scale_factor as f64); + let sqrt = x.sqrt(); + let rsqrt = 1.0 / sqrt; + let rsqrt = (rsqrt * (scale_factor as f64)).round() as i64; + map.insert(i as i64, rsqrt); + } + map + } + + fn get_map(&self) -> &HashMap { + &self.config.maps.get(&GadgetType::Rsqrt).unwrap()[0] + } + + fn get_selector(&self) -> halo2_proofs::plonk::Selector { + self.config.selectors.get(&GadgetType::Rsqrt).unwrap()[0] + } +} + +impl Gadget for RsqrtGadgetChip { + fn name(&self) -> String { + "RsqrtGadget".to_string() + } + + fn num_cols_per_op(&self) -> usize { + as NonLinearGadget>::num_cols_per_op() + } + + fn num_inputs_per_row(&self) -> 
usize { + self.config.columns.len() / self.num_cols_per_op() + } + + fn num_outputs_per_row(&self) -> usize { + self.config.columns.len() / self.num_cols_per_op() + } + + fn load_lookups(&self, layouter: impl Layouter) -> Result<(), Error> { + NonLinearGadget::load_lookups(self, layouter, self.config.clone(), GadgetType::Rsqrt)?; + Ok(()) + } + + fn op_row_region( + &self, + region: &mut Region, + row_offset: usize, + vec_inputs: &Vec>>, + single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + NonLinearGadget::op_row_region( + self, + region, + row_offset, + vec_inputs, + single_inputs, + self.config.clone(), + ) + } + + fn forward( + &self, + layouter: impl halo2_proofs::circuit::Layouter, + vec_inputs: &Vec>>, + single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + NonLinearGadget::forward(self, layouter, vec_inputs, single_inputs) + } +} diff --git a/mnist_zkml/src/gadgets/nonlinear/sqrt.rs b/mnist_zkml/src/gadgets/nonlinear/sqrt.rs new file mode 100644 index 0000000..7faeecb --- /dev/null +++ b/mnist_zkml/src/gadgets/nonlinear/sqrt.rs @@ -0,0 +1,101 @@ +use std::{collections::HashMap, marker::PhantomData, rc::Rc}; + +use halo2_proofs::{ + circuit::{AssignedCell, Layouter, Region}, + halo2curves::ff::PrimeField, + plonk::{ConstraintSystem, Error}, +}; + +use super::{ + super::gadget::{Gadget, GadgetConfig, GadgetType}, + non_linearity::NonLinearGadget, +}; + +pub struct SqrtGadgetChip { + config: Rc, + _marker: PhantomData, +} + +impl SqrtGadgetChip { + pub fn construct(config: Rc) -> Self { + Self { + config, + _marker: PhantomData, + } + } + + pub fn configure(meta: &mut ConstraintSystem, gadget_config: GadgetConfig) -> GadgetConfig { + as NonLinearGadget>::configure(meta, gadget_config, GadgetType::Sqrt) + } +} + +impl NonLinearGadget for SqrtGadgetChip { + fn generate_map(scale_factor: u64, min_val: i64, num_rows: i64) -> HashMap { + let mut map = HashMap::new(); + for i in 0..num_rows { + let shifted = i + min_val; + let x = (shifted as 
f64) / (scale_factor as f64); + let sqrt = x.sqrt(); + let sqrt = (sqrt * (scale_factor as f64)).round() as i64; + map.insert(i as i64, sqrt); + } + map + } + + fn get_map(&self) -> &HashMap { + &self.config.maps.get(&GadgetType::Sqrt).unwrap()[0] + } + + fn get_selector(&self) -> halo2_proofs::plonk::Selector { + self.config.selectors.get(&GadgetType::Sqrt).unwrap()[0] + } +} + +impl Gadget for SqrtGadgetChip { + fn name(&self) -> String { + "SqrtGadget".to_string() + } + + fn num_cols_per_op(&self) -> usize { + as NonLinearGadget>::num_cols_per_op() + } + + fn num_inputs_per_row(&self) -> usize { + self.config.columns.len() / self.num_cols_per_op() + } + + fn num_outputs_per_row(&self) -> usize { + self.config.columns.len() / self.num_cols_per_op() + } + + fn load_lookups(&self, layouter: impl Layouter) -> Result<(), Error> { + NonLinearGadget::load_lookups(self, layouter, self.config.clone(), GadgetType::Sqrt)?; + Ok(()) + } + + fn op_row_region( + &self, + region: &mut Region, + row_offset: usize, + vec_inputs: &Vec>>, + single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + NonLinearGadget::op_row_region( + self, + region, + row_offset, + vec_inputs, + single_inputs, + self.config.clone(), + ) + } + + fn forward( + &self, + layouter: impl halo2_proofs::circuit::Layouter, + vec_inputs: &Vec>>, + single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + NonLinearGadget::forward(self, layouter, vec_inputs, single_inputs) + } +} diff --git a/mnist_zkml/src/gadgets/nonlinear/tanh.rs b/mnist_zkml/src/gadgets/nonlinear/tanh.rs new file mode 100644 index 0000000..1afa94b --- /dev/null +++ b/mnist_zkml/src/gadgets/nonlinear/tanh.rs @@ -0,0 +1,104 @@ +use std::{collections::HashMap, marker::PhantomData, rc::Rc}; + +use halo2_proofs::{ + circuit::{AssignedCell, Layouter, Region}, + halo2curves::ff::PrimeField, + plonk::{ConstraintSystem, Error}, +}; + +use super::{ + super::gadget::{Gadget, GadgetConfig, GadgetType}, + non_linearity::NonLinearGadget, +}; + 
+pub struct TanhGadgetChip { + config: Rc, + _marker: PhantomData, +} + +impl TanhGadgetChip { + pub fn construct(config: Rc) -> Self { + Self { + config, + _marker: PhantomData, + } + } + + pub fn configure(meta: &mut ConstraintSystem, gadget_config: GadgetConfig) -> GadgetConfig { + as NonLinearGadget>::configure(meta, gadget_config, GadgetType::Tanh) + } +} + +impl NonLinearGadget for TanhGadgetChip { + fn generate_map(scale_factor: u64, min_val: i64, num_rows: i64) -> HashMap { + let scale_factor = scale_factor as f64; + + let mut map = HashMap::new(); + for i in 0..num_rows { + let shifted = i + min_val; + let x = (shifted as f64) / scale_factor; + let y = x.tanh(); + let y = (y * scale_factor).round() as i64; + map.insert(i as i64, y); + } + + map + } + + fn get_map(&self) -> &HashMap { + &self.config.maps.get(&GadgetType::Tanh).unwrap()[0] + } + + fn get_selector(&self) -> halo2_proofs::plonk::Selector { + self.config.selectors.get(&GadgetType::Tanh).unwrap()[0] + } +} + +impl Gadget for TanhGadgetChip { + fn name(&self) -> String { + "TanhGadgetChip".to_string() + } + + fn num_cols_per_op(&self) -> usize { + as NonLinearGadget>::num_cols_per_op() + } + + fn num_inputs_per_row(&self) -> usize { + self.config.columns.len() / self.num_cols_per_op() + } + + fn num_outputs_per_row(&self) -> usize { + self.config.columns.len() / self.num_cols_per_op() + } + + fn load_lookups(&self, layouter: impl Layouter) -> Result<(), Error> { + NonLinearGadget::load_lookups(self, layouter, self.config.clone(), GadgetType::Tanh)?; + Ok(()) + } + + fn op_row_region( + &self, + region: &mut Region, + row_offset: usize, + vec_inputs: &Vec>>, + single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + NonLinearGadget::op_row_region( + self, + region, + row_offset, + vec_inputs, + single_inputs, + self.config.clone(), + ) + } + + fn forward( + &self, + layouter: impl halo2_proofs::circuit::Layouter, + vec_inputs: &Vec>>, + single_inputs: &Vec<&AssignedCell>, + ) -> Result>, 
Error> { + NonLinearGadget::forward(self, layouter, vec_inputs, single_inputs) + } +} diff --git a/mnist_zkml/src/gadgets/sqrt_big.rs b/mnist_zkml/src/gadgets/sqrt_big.rs new file mode 100644 index 0000000..a44e3ed --- /dev/null +++ b/mnist_zkml/src/gadgets/sqrt_big.rs @@ -0,0 +1,194 @@ +use std::{marker::PhantomData, rc::Rc}; + +use halo2_proofs::{ + circuit::{AssignedCell, Layouter, Region}, + halo2curves::ff::PrimeField, + plonk::{ConstraintSystem, Error, Expression}, + poly::Rotation, +}; + +use crate::gadgets::gadget::convert_to_u64; + +use super::gadget::{Gadget, GadgetConfig, GadgetType}; + +type SqrtBigConfig = GadgetConfig; + +pub struct SqrtBigChip { + config: Rc, + _marker: PhantomData, +} + +impl SqrtBigChip { + pub fn construct(config: Rc) -> Self { + Self { + config, + _marker: PhantomData, + } + } + + pub fn num_cols_per_op() -> usize { + 3 + } + + pub fn configure(meta: &mut ConstraintSystem, gadget_config: GadgetConfig) -> GadgetConfig { + let selector = meta.complex_selector(); + let two = Expression::Constant(F::from(2)); + let columns = gadget_config.columns; + + let tables = gadget_config.tables; + + let inp_lookup = tables.get(&GadgetType::InputLookup).unwrap()[0]; + + // TODO: prove that these constraints work + meta.create_gate("sqrt_big arithm", |meta| { + let s = meta.query_selector(selector); + + let mut constraints = vec![]; + for op_idx in 0..columns.len() / Self::num_cols_per_op() { + let offset = op_idx * Self::num_cols_per_op(); + let inp = meta.query_advice(columns[offset + 0], Rotation::cur()); + let sqrt = meta.query_advice(columns[offset + 1], Rotation::cur()); + let rem = meta.query_advice(columns[offset + 2], Rotation::cur()); + + let lhs = inp.clone(); + let rhs = sqrt.clone() * sqrt.clone() + rem.clone(); + constraints.push(s.clone() * (lhs - rhs)); + } + constraints + }); + + for op_idx in 0..columns.len() / Self::num_cols_per_op() { + let offset = op_idx * Self::num_cols_per_op(); + meta.lookup("sqrt_big sqrt lookup", 
|meta| { + let s = meta.query_selector(selector); + let sqrt = meta.query_advice(columns[offset + 1], Rotation::cur()); + + vec![(s.clone() * sqrt, inp_lookup)] + }); + + meta.lookup("sqrt_big rem lookup", |meta| { + let s = meta.query_selector(selector); + let sqrt = meta.query_advice(columns[offset + 1], Rotation::cur()); + let rem = meta.query_advice(columns[offset + 2], Rotation::cur()); + + vec![(s.clone() * (rem + sqrt), inp_lookup)] + }); + + meta.lookup("sqrt_big sqrt - rem lookup", |meta| { + let s = meta.query_selector(selector); + let sqrt = meta.query_advice(columns[offset + 1], Rotation::cur()); + let rem = meta.query_advice(columns[offset + 2], Rotation::cur()); + + vec![(s.clone() * (two.clone() * sqrt - rem), inp_lookup)] + }); + } + + let mut selectors = gadget_config.selectors; + selectors.insert(GadgetType::SqrtBig, vec![selector]); + + GadgetConfig { + columns, + tables, + selectors, + ..gadget_config + } + } +} + +impl Gadget for SqrtBigChip { + fn name(&self) -> String { + "sqrt_big".to_string() + } + + fn num_cols_per_op(&self) -> usize { + Self::num_cols_per_op() + } + + fn num_inputs_per_row(&self) -> usize { + self.config.columns.len() / self.num_cols_per_op() + } + + fn num_outputs_per_row(&self) -> usize { + self.num_inputs_per_row() + } + + fn op_row_region( + &self, + region: &mut Region, + row_offset: usize, + vec_inputs: &Vec>>, + _single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + let inps = &vec_inputs[0]; + + if self.config.use_selectors { + let selector = self.config.selectors.get(&GadgetType::SqrtBig).unwrap()[0]; + selector.enable(region, row_offset)?; + } + + let mut outp_cells = vec![]; + for (i, inp) in inps.iter().enumerate() { + let offset = i * self.num_cols_per_op(); + inp.copy_advice( + || "sqrt_big", + region, + self.config.columns[offset], + row_offset, + )?; + + let outp = inp.value().map(|x: &F| { + let inp_val = convert_to_u64(x) as i64; + let fsqrt = (inp_val as f64).sqrt(); + let sqrt = fsqrt.round() 
as i64; + let rem = inp_val - sqrt * sqrt; + (sqrt, rem) + }); + + let sqrt_cell = region.assign_advice( + || "sqrt_big", + self.config.columns[offset + 1], + row_offset, + || outp.map(|x| F::from(x.0 as u64)), + )?; + + let _rem_cell = region.assign_advice( + || "sqrt_big", + self.config.columns[offset + 2], + row_offset, + || { + outp.map(|x| { + let rem_pos = x.1 + x.0; + F::from(rem_pos as u64) - F::from(x.0 as u64) + }) + }, + )?; + outp_cells.push(sqrt_cell); + } + + Ok(outp_cells) + } + + fn forward( + &self, + mut layouter: impl Layouter, + vec_inputs: &Vec>>, + single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + let zero = &single_inputs[0]; + + let mut inp = vec_inputs[0].clone(); + let inp_len = inp.len(); + while inp.len() % self.num_inputs_per_row() != 0 { + inp.push(zero); + } + + let vec_inputs = vec![inp]; + let outp = self.op_aligned_rows( + layouter.namespace(|| format!("forward row {}", self.name())), + &vec_inputs, + single_inputs, + )?; + + Ok(outp[0..inp_len].to_vec()) + } +} diff --git a/mnist_zkml/src/gadgets/square.rs b/mnist_zkml/src/gadgets/square.rs new file mode 100644 index 0000000..9f13742 --- /dev/null +++ b/mnist_zkml/src/gadgets/square.rs @@ -0,0 +1,122 @@ +use std::{marker::PhantomData, rc::Rc}; + +use halo2_proofs::{ + circuit::{AssignedCell, Region}, + halo2curves::ff::PrimeField, + plonk::{ConstraintSystem, Error}, + poly::Rotation, +}; + +use super::gadget::{Gadget, GadgetConfig, GadgetType}; + +pub struct SquareGadgetChip { + config: Rc, + _marker: PhantomData, +} + +impl SquareGadgetChip { + pub fn construct(config: Rc) -> Self { + Self { + config, + _marker: PhantomData, + } + } + + // TODO: it would be more efficient to do the division here directly + pub fn configure(meta: &mut ConstraintSystem, gadget_config: GadgetConfig) -> GadgetConfig { + let selector = meta.selector(); + let columns = gadget_config.columns; + + meta.create_gate("square gate", |meta| { + let s = meta.query_selector(selector); + let 
gate_inp = meta.query_advice(columns[0], Rotation::cur()); + let gate_output = meta.query_advice(columns[1], Rotation::cur()); + + let res = gate_inp.clone() * gate_inp; + + vec![s * (res - gate_output)] + }); + + let mut selectors = gadget_config.selectors; + selectors.insert(GadgetType::Square, vec![selector]); + + GadgetConfig { + columns, + selectors, + ..gadget_config + } + } +} + +impl Gadget for SquareGadgetChip { + fn name(&self) -> String { + "SquareChip".to_string() + } + + fn num_cols_per_op(&self) -> usize { + 2 + } + + fn num_inputs_per_row(&self) -> usize { + self.config.columns.len() / self.num_cols_per_op() + } + + fn num_outputs_per_row(&self) -> usize { + self.num_inputs_per_row() + } + + fn op_row_region( + &self, + region: &mut Region, + row_offset: usize, + vec_inputs: &Vec>>, + _single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + assert_eq!(vec_inputs.len(), 1); + + if self.config.use_selectors { + let selector = self.config.selectors.get(&GadgetType::Square).unwrap()[0]; + selector.enable(region, row_offset)?; + } + + let inps = &vec_inputs[0]; + let mut outp = vec![]; + for (i, inp) in inps.iter().enumerate() { + let offset = i * self.num_cols_per_op(); + inp.copy_advice(|| "", region, self.config.columns[offset], row_offset)?; + let outp_val = inp.value().map(|x: &F| x.to_owned() * x.to_owned()); + let outp_cell = region.assign_advice( + || "square output", + self.config.columns[offset + 1], + row_offset, + || outp_val, + )?; + outp.push(outp_cell); + } + + Ok(outp) + } + + fn forward( + &self, + mut layouter: impl halo2_proofs::circuit::Layouter, + vec_inputs: &Vec>>, + single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + let zero = &single_inputs[0]; + + let mut inp = vec_inputs[0].clone(); + let initial_len = inp.len(); + while inp.len() % self.num_inputs_per_row() != 0 { + inp.push(zero); + } + + let vec_inputs = vec![inp]; + let res = self.op_aligned_rows( + layouter.namespace(|| format!("forward row {}", 
self.name())), + &vec_inputs, + single_inputs, + )?; + Ok(res[0..initial_len].to_vec()) + } +} diff --git a/mnist_zkml/src/gadgets/squared_diff.rs b/mnist_zkml/src/gadgets/squared_diff.rs new file mode 100644 index 0000000..825542c --- /dev/null +++ b/mnist_zkml/src/gadgets/squared_diff.rs @@ -0,0 +1,137 @@ +use std::{marker::PhantomData, rc::Rc}; + +use halo2_proofs::{ + circuit::{AssignedCell, Layouter, Region}, + halo2curves::ff::PrimeField, + plonk::{ConstraintSystem, Error}, + poly::Rotation, +}; + +use super::gadget::{Gadget, GadgetConfig, GadgetType}; + +type SquaredDiffConfig = GadgetConfig; + +pub struct SquaredDiffGadgetChip { + config: Rc, + _marker: PhantomData, +} + +impl SquaredDiffGadgetChip { + pub fn construct(config: Rc) -> Self { + Self { + config, + _marker: PhantomData, + } + } + + pub fn num_cols_per_op() -> usize { + 3 + } + + pub fn configure(meta: &mut ConstraintSystem, gadget_config: GadgetConfig) -> GadgetConfig { + let selector = meta.selector(); + let columns = gadget_config.columns; + + meta.create_gate("squared diff", |meta| { + let s = meta.query_selector(selector); + let mut constraints = vec![]; + for i in 0..columns.len() / Self::num_cols_per_op() { + let offset = i * Self::num_cols_per_op(); + let inp1 = meta.query_advice(columns[offset + 0], Rotation::cur()); + let inp2 = meta.query_advice(columns[offset + 1], Rotation::cur()); + let outp = meta.query_advice(columns[offset + 2], Rotation::cur()); + + let res = (inp1 - inp2).square(); + constraints.append(&mut vec![s.clone() * (res - outp)]) + } + + constraints + }); + + let mut selectors = gadget_config.selectors; + selectors.insert(GadgetType::SquaredDiff, vec![selector]); + + GadgetConfig { + columns, + selectors, + ..gadget_config + } + } +} + +impl Gadget for SquaredDiffGadgetChip { + fn name(&self) -> String { + "SquaredDiff".to_string() + } + + fn num_cols_per_op(&self) -> usize { + Self::num_cols_per_op() + } + + fn num_inputs_per_row(&self) -> usize { + 
self.config.columns.len() / self.num_cols_per_op() + } + + fn num_outputs_per_row(&self) -> usize { + self.config.columns.len() / self.num_cols_per_op() + } + + fn op_row_region( + &self, + region: &mut Region, + row_offset: usize, + vec_inputs: &Vec>>, + _single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + let inp1 = &vec_inputs[0]; + let inp2 = &vec_inputs[1]; + assert_eq!(inp1.len(), inp2.len()); + + let columns = &self.config.columns; + + if self.config.use_selectors { + let selector = self.config.selectors.get(&GadgetType::SquaredDiff).unwrap()[0]; + selector.enable(region, row_offset)?; + } + + let mut outps = vec![]; + for i in 0..inp1.len() { + let offset = i * self.num_cols_per_op(); + let inp1 = inp1[i].copy_advice(|| "", region, columns[offset + 0], row_offset)?; + let inp2 = inp2[i].copy_advice(|| "", region, columns[offset + 1], row_offset)?; + let outp = inp1.value().map(|x: &F| x.to_owned()) - inp2.value().map(|x: &F| x.to_owned()); + let outp = outp * outp; + + let outp = region.assign_advice(|| "", columns[offset + 2], row_offset, || outp)?; + outps.push(outp); + } + Ok(outps) + } + + fn forward( + &self, + mut layouter: impl Layouter, + vec_inputs: &Vec>>, + single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + let zero = &single_inputs[0]; + + let mut inp1 = vec_inputs[0].clone(); + let mut inp2 = vec_inputs[1].clone(); + let initial_len = inp1.len(); + while inp1.len() % self.num_inputs_per_row() != 0 { + inp1.push(zero); + inp2.push(zero); + } + + let vec_inputs = vec![inp1, inp2]; + + let res = self.op_aligned_rows( + layouter.namespace(|| format!("forward row {}", self.name())), + &vec_inputs, + single_inputs, + )?; + + Ok(res[0..initial_len].to_vec()) + } +} diff --git a/mnist_zkml/src/gadgets/sub_pairs.rs b/mnist_zkml/src/gadgets/sub_pairs.rs new file mode 100644 index 0000000..2b3aaca --- /dev/null +++ b/mnist_zkml/src/gadgets/sub_pairs.rs @@ -0,0 +1,135 @@ +use std::{marker::PhantomData, rc::Rc}; + +use halo2_proofs::{ 
+ circuit::{AssignedCell, Layouter, Region}, + halo2curves::ff::PrimeField, + plonk::{ConstraintSystem, Error}, + poly::Rotation, +}; + +use super::gadget::{Gadget, GadgetConfig, GadgetType}; + +type SubPairsConfig = GadgetConfig; + +pub struct SubPairsChip { + config: Rc, + _marker: PhantomData, +} + +impl SubPairsChip { + pub fn construct(config: Rc) -> Self { + Self { + config, + _marker: PhantomData, + } + } + + pub fn num_cols_per_op() -> usize { + 3 + } + + pub fn configure(meta: &mut ConstraintSystem, gadget_config: GadgetConfig) -> GadgetConfig { + let selector = meta.selector(); + let columns = gadget_config.columns; + + meta.create_gate("sub pair", |meta| { + let s = meta.query_selector(selector); + let mut constraints = vec![]; + for i in 0..columns.len() / Self::num_cols_per_op() { + let offset = i * Self::num_cols_per_op(); + let inp1 = meta.query_advice(columns[offset + 0], Rotation::cur()); + let inp2 = meta.query_advice(columns[offset + 1], Rotation::cur()); + let outp = meta.query_advice(columns[offset + 2], Rotation::cur()); + + let res = inp1 - inp2; + constraints.append(&mut vec![s.clone() * (res - outp)]) + } + + constraints + }); + + let mut selectors = gadget_config.selectors; + selectors.insert(GadgetType::SubPairs, vec![selector]); + + GadgetConfig { + columns, + selectors, + ..gadget_config + } + } +} + +impl Gadget for SubPairsChip { + fn name(&self) -> String { + "sub pairs chip".to_string() + } + + fn num_cols_per_op(&self) -> usize { + Self::num_cols_per_op() + } + + fn num_inputs_per_row(&self) -> usize { + self.config.columns.len() / self.num_cols_per_op() + } + + fn num_outputs_per_row(&self) -> usize { + self.config.columns.len() / self.num_cols_per_op() + } + + fn op_row_region( + &self, + region: &mut Region, + row_offset: usize, + vec_inputs: &Vec>>, + _single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + let inp1 = &vec_inputs[0]; + let inp2 = &vec_inputs[1]; + assert_eq!(inp1.len(), inp2.len()); + + let columns = 
&self.config.columns; + + if self.config.use_selectors { + let selector = self.config.selectors.get(&GadgetType::SubPairs).unwrap()[0]; + selector.enable(region, row_offset)?; + } + + let mut outps = vec![]; + for i in 0..inp1.len() { + let offset = i * self.num_cols_per_op(); + let inp1 = inp1[i].copy_advice(|| "", region, columns[offset + 0], row_offset)?; + let inp2 = inp2[i].copy_advice(|| "", region, columns[offset + 1], row_offset)?; + let outp = inp1.value().map(|x: &F| x.to_owned()) - inp2.value().map(|x: &F| x.to_owned()); + + let outp = region.assign_advice(|| "", columns[offset + 2], row_offset, || outp)?; + outps.push(outp); + } + Ok(outps) + } + + fn forward( + &self, + mut layouter: impl Layouter, + vec_inputs: &Vec>>, + single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + let zero = &single_inputs[0]; + + let mut inp1 = vec_inputs[0].clone(); + let mut inp2 = vec_inputs[1].clone(); + let initial_len = inp1.len(); + while inp1.len() % self.num_inputs_per_row() != 0 { + inp1.push(zero); + inp2.push(zero); + } + + let vec_inputs = vec![inp1, inp2]; + + let res = self.op_aligned_rows( + layouter.namespace(|| format!("forward row {}", self.name())), + &vec_inputs, + single_inputs, + )?; + Ok(res[0..initial_len].to_vec()) + } +} diff --git a/mnist_zkml/src/gadgets/update.rs b/mnist_zkml/src/gadgets/update.rs new file mode 100644 index 0000000..06338a1 --- /dev/null +++ b/mnist_zkml/src/gadgets/update.rs @@ -0,0 +1,209 @@ +use std::marker::PhantomData; + +use halo2_proofs::{ + circuit::{AssignedCell, Layouter, Region}, + halo2curves::ff::PrimeField, + plonk::{ConstraintSystem, Error, Expression}, + poly::Rotation, +}; + +use crate::gadgets::gadget::{convert_to_u64, GadgetConfig}; + +use super::gadget::{Gadget, GadgetType}; + +type UpdateConfig = GadgetConfig; + +#[derive(Clone, Debug)] +pub struct UpdateGadgetChip { + config: UpdateConfig, + _marker: PhantomData, +} + +impl UpdateGadgetChip { + pub fn construct(config: UpdateConfig) -> Self { + 
Self { + config, + _marker: PhantomData, + } + } + + pub fn num_cols_per_op() -> usize { + 4 + } + + pub fn configure(meta: &mut ConstraintSystem, gadget_config: GadgetConfig) -> UpdateConfig { + let tables = &gadget_config.tables; + let mod_lookup = tables.get(&GadgetType::InputLookup).unwrap()[0]; + + let columns = gadget_config.columns; + let selector = meta.complex_selector(); + + let div_val = gadget_config.scale_factor; + let eta: u64 = (gadget_config.scale_factor as f64 * gadget_config.eta) as u64; + + meta.create_gate("updater_arith", |meta| { + let s = meta.query_selector(selector); + + let sf = Expression::Constant(F::from(div_val as u64)); + let eta = Expression::Constant(F::from(eta as u64)); + + let mut constraints = vec![]; + for op_idx in 0..columns.len() / Self::num_cols_per_op() { + let offset = op_idx * Self::num_cols_per_op(); + let w = meta.query_advice(columns[offset], Rotation::cur()); + let dw = meta.query_advice(columns[offset + 1], Rotation::cur()); + let div = meta.query_advice(columns[offset + 2], Rotation::cur()); + let mod_res = meta.query_advice(columns[offset + 3], Rotation::cur()); + + let expr = (w * sf.clone() - dw * eta.clone()) - (div * sf.clone() + mod_res); + constraints.push(s.clone() * expr); + } + constraints + }); + + for op_idx in 0..columns.len() / Self::num_cols_per_op() { + let offset = op_idx * Self::num_cols_per_op(); + + // Check that mod is smaller than SF + meta.lookup("max inp1", |meta| { + let s = meta.query_selector(selector); + let mod_res = meta.query_advice(columns[offset + 3], Rotation::cur()); + + // Constrains that the modulus \in [0, DIV_VAL) + vec![(s.clone() * mod_res.clone(), mod_lookup)] + }); + } + + let mut selectors = gadget_config.selectors; + selectors.insert(GadgetType::Update, vec![selector]); + + UpdateConfig { + columns, + selectors, + ..gadget_config + } + } +} + +impl Gadget for UpdateGadgetChip { + fn name(&self) -> String { + "updater chip".to_string() + } + + fn num_cols_per_op(&self) -> 
usize { + Self::num_cols_per_op() + } + + fn num_inputs_per_row(&self) -> usize { + self.config.columns.len() / self.num_cols_per_op() + } + + fn num_outputs_per_row(&self) -> usize { + self.config.columns.len() / self.num_cols_per_op() + } + + fn op_row_region( + &self, + region: &mut Region, + row_offset: usize, + vec_inputs: &Vec>>, + _single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + let div_val = self.config.scale_factor as i64; + let div_val_f = F::from(div_val as u64); + let eta = div_val / 1000; + let eta = F::from(eta as u64); + + let div_outp_min_val = self.config.div_outp_min_val; + let div_inp_min_val_pos_i64 = -self.config.shift_min_val; + let div_inp_min_val_pos = F::from(div_inp_min_val_pos_i64 as u64); + + let columns = &self.config.columns; + + if self.config.use_selectors { + let selector = self.config.selectors.get(&GadgetType::Update).unwrap()[0]; + selector.enable(region, row_offset)?; + } + + let w = &vec_inputs[0]; + let dw = &vec_inputs[1]; + + let mut output_cells = vec![]; + + for i in 0..w.len() { + let offset = i * self.num_cols_per_op(); + let _w_cell = w[i].copy_advice(|| "", region, columns[offset + 0], row_offset)?; + let _dw_cell = dw[i].copy_advice(|| "", region, columns[offset + 1], row_offset)?; + + let w_val = w[i].value().map(|x: &F| x.to_owned()); + let dw_val = dw[i].value().map(|x: &F| x.to_owned()); + let out_scaled = w_val.zip(dw_val).map(|(w, dw)| w * div_val_f - dw * eta); + + let div_mod = out_scaled.map(|x| { + let x_pos = x + div_inp_min_val_pos; + let x_pos = if x_pos > F::ZERO { + x_pos + } else { + x_pos + div_val_f + }; + let inp = convert_to_u64(&x_pos); + + let div_res = inp as i64 / div_val - (div_inp_min_val_pos_i64 as i64 / div_val); + let mod_res = inp as i64 % div_val; + (div_res, mod_res) + }); + + let div_res_cell = region + .assign_advice( + || "div_res", + self.config.columns[offset + 2], + row_offset, + || { + div_mod.map(|(x, _): (i64, i64)| { + F::from((x - div_outp_min_val as i64) as 
u64) - F::from(-div_outp_min_val as u64) + }) + }, + ) + .unwrap(); + + let _mod_res_cell = region + .assign_advice( + || "mod_res", + self.config.columns[offset + 3], + row_offset, + || div_mod.map(|(_, x): (i64, i64)| F::from(x as u64)), + ) + .unwrap(); + + output_cells.push(div_res_cell); + } + Ok(output_cells) + } + + fn forward( + &self, + mut layouter: impl Layouter, + vec_inputs: &Vec>>, + single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + let zero = &single_inputs[0]; + let mut w = vec_inputs[0].clone(); + let mut dw = vec_inputs[1].clone(); + + let initial_len = w.len(); + while !w.len() % self.num_cols_per_op() == 0 { + w.push(zero); + } + while !dw.len() % self.num_cols_per_op() == 0 { + dw.push(zero); + } + + let res = self.op_aligned_rows( + layouter.namespace(|| format!("forward row {}", self.name())), + &vec![w, dw], + single_inputs, + )?; + + Ok(res[0..initial_len].to_vec()) + } +} diff --git a/mnist_zkml/src/gadgets/var_div.rs b/mnist_zkml/src/gadgets/var_div.rs new file mode 100644 index 0000000..b58fbfa --- /dev/null +++ b/mnist_zkml/src/gadgets/var_div.rs @@ -0,0 +1,210 @@ +use std::{marker::PhantomData, rc::Rc}; + +use halo2_proofs::{ + circuit::{AssignedCell, Layouter, Region}, + halo2curves::ff::PrimeField, + plonk::{ConstraintSystem, Error, Expression}, + poly::Rotation, +}; +use rounded_div::RoundedDiv; + +use super::gadget::{convert_to_u128, Gadget, GadgetConfig, GadgetType}; + +type VarDivRoundConfig = GadgetConfig; + +pub struct VarDivRoundChip { + config: Rc, + _marker: PhantomData, +} + +impl VarDivRoundChip { + pub fn construct(config: Rc) -> Self { + Self { + config, + _marker: PhantomData, + } + } + + pub fn num_cols_per_op() -> usize { + 3 + } + + pub fn configure(meta: &mut ConstraintSystem, gadget_config: GadgetConfig) -> GadgetConfig { + let columns = gadget_config.columns; + let selector = meta.complex_selector(); + let two = Expression::Constant(F::from(2)); + + let tables = gadget_config.tables; + let lookup = 
tables.get(&GadgetType::InputLookup).unwrap()[0]; + + // a | c | r | ... | b + // (2 * a + b) = (2 * b) * c + r + // b - r \in [0, 2^N) <-- forces b > r + meta.create_gate("var_div_arithm", |meta| { + let s = meta.query_selector(selector); + let mut constraints = vec![]; + + let b = meta.query_advice(columns[columns.len() - 1], Rotation::cur()); + for i in 0..(columns.len() - 1) / Self::num_cols_per_op() { + let offset = i * Self::num_cols_per_op(); + let a = meta.query_advice(columns[offset], Rotation::cur()); + let c = meta.query_advice(columns[offset + 1], Rotation::cur()); + let r = meta.query_advice(columns[offset + 2], Rotation::cur()); + + let lhs = a.clone() * two.clone() + b.clone(); + let rhs = b.clone() * two.clone() * c + r; + constraints.push(s.clone() * (lhs - rhs)); + } + + constraints + }); + + for i in 0..(columns.len() - 1) / Self::num_cols_per_op() { + let offset = i * Self::num_cols_per_op(); + // r \in [0, 2^N) + meta.lookup("var div range checks r", |meta| { + let s = meta.query_selector(selector); + let r = meta.query_advice(columns[offset + 2], Rotation::cur()); + + vec![(s.clone() * r, lookup)] + }); + + // 2 * b - r \in [0, 2^N) + meta.lookup("var div range checks 2b-r", |meta| { + let s = meta.query_selector(selector); + let b = meta.query_advice(columns[columns.len() - 1], Rotation::cur()); + let r = meta.query_advice(columns[offset + 2], Rotation::cur()); + + vec![(s.clone() * (two.clone() * b - r), lookup)] + }); + } + // b \in [0, 2^N) + meta.lookup("var div range checks b", |meta| { + let s = meta.query_selector(selector); + let b = meta.query_advice(columns[columns.len() - 1], Rotation::cur()); + + vec![(s.clone() * b, lookup)] + }); + + let mut selectors = gadget_config.selectors; + selectors.insert(GadgetType::VarDivRound, vec![selector]); + + GadgetConfig { + columns, + tables, + selectors, + ..gadget_config + } + } +} + +impl Gadget for VarDivRoundChip { + fn name(&self) -> String { + "VarDivRoundChip".to_string() + } + + fn 
num_cols_per_op(&self) -> usize { + Self::num_cols_per_op() + } + + fn num_inputs_per_row(&self) -> usize { + (self.config.columns.len() - 1) / self.num_cols_per_op() + } + + fn num_outputs_per_row(&self) -> usize { + self.num_inputs_per_row() + } + + fn op_row_region( + &self, + region: &mut Region, + row_offset: usize, + vec_inputs: &Vec>>, + single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + let a_vec = &vec_inputs[0]; + // let zero = single_inputs[0].clone(); + let b = &single_inputs[1]; + + let div_outp_min_val_i64 = self.config.div_outp_min_val; + let div_inp_min_val_pos_i64 = -self.config.shift_min_val; + + if self.config.use_selectors { + let selector = self.config.selectors.get(&GadgetType::VarDivRound).unwrap()[0]; + selector.enable(region, row_offset)?; + } + + b.copy_advice( + || "", + region, + self.config.columns[self.config.columns.len() - 1], + row_offset, + )?; + + let mut div_out = vec![]; + for (i, a) in a_vec.iter().enumerate() { + let offset = i * self.num_cols_per_op(); + a.copy_advice(|| "", region, self.config.columns[offset], row_offset)?; + + let div_mod = a.value().zip(b.value()).map(|(a, b)| { + let b = convert_to_u128(b); + // Needs to be divisible by b + let div_inp_min_val_pos_i64 = div_inp_min_val_pos_i64 / (b as i64) * (b as i64); + let div_inp_min_val_pos = F::from(div_inp_min_val_pos_i64 as u64); + + let a_pos = *a + div_inp_min_val_pos; + let a = convert_to_u128(&a_pos); + // c = (2 * a + b) / (2 * b) + let c_pos = a.rounded_div(b); + let c = (c_pos as i128 - (div_inp_min_val_pos_i64 as u128 / b) as i128) as i64; + + // r = (2 * a + b) % (2 * b) + let rem_floor = (a as i128) - (c_pos * b) as i128; + let r = 2 * rem_floor + (b as i128); + let r = r as i64; + (c, r) + }); + + let div_cell = region.assign_advice( + || "", + self.config.columns[offset + 1], + row_offset, + || { + div_mod.map(|(c, _)| { + let offset = F::from(-div_outp_min_val_i64 as u64); + let c = F::from((c - div_outp_min_val_i64) as u64); + c - offset 
+ }) + }, + )?; + let _mod_cell = region.assign_advice( + || "", + self.config.columns[offset + 2], + row_offset, + || div_mod.map(|(_, r)| F::from(r as u64)), + )?; + div_out.push(div_cell); + } + + Ok(div_out) + } + + fn forward( + &self, + mut layouter: impl Layouter, + vec_inputs: &Vec>>, + single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + let mut inps = vec_inputs[0].clone(); + let initial_len = inps.len(); + + // Needed to pad: bias - bias = 0 + let default = &single_inputs[0]; + while inps.len() % self.num_inputs_per_row() != 0 { + inps.push(&default); + } + + let res = self.op_aligned_rows(layouter.namespace(|| "var_div"), &vec![inps], single_inputs)?; + Ok(res[..initial_len].to_vec()) + } +} diff --git a/mnist_zkml/src/gadgets/var_div_big.rs b/mnist_zkml/src/gadgets/var_div_big.rs new file mode 100644 index 0000000..d1ea412 --- /dev/null +++ b/mnist_zkml/src/gadgets/var_div_big.rs @@ -0,0 +1,278 @@ +use std::{marker::PhantomData, rc::Rc}; + +use halo2_proofs::{ + circuit::{AssignedCell, Layouter, Region}, + halo2curves::ff::PrimeField, + plonk::{ConstraintSystem, Error, Expression}, + poly::Rotation, +}; +use rounded_div::RoundedDiv; + +use super::gadget::{convert_to_u128, Gadget, GadgetConfig, GadgetType}; + +pub struct VarDivRoundBigChip { + config: Rc, + _marker: PhantomData, +} + +impl VarDivRoundBigChip { + pub fn construct(config: Rc) -> Self { + Self { + config, + _marker: PhantomData, + } + } + + pub fn num_cols_per_op() -> usize { + 7 + } + + pub fn configure(meta: &mut ConstraintSystem, gadget_config: GadgetConfig) -> GadgetConfig { + let columns = gadget_config.columns; + let selector = meta.complex_selector(); + let two = Expression::Constant(F::from(2)); + let range = Expression::Constant(F::from(gadget_config.num_rows as u64)); + + let tables = gadget_config.tables; + let lookup = tables.get(&GadgetType::InputLookup).unwrap()[0]; + + // a | c | r | (2 b - r)_1 | (2 b - r)_0 | r_1 | r_0 | ... 
| b + // a / b = c + meta.create_gate("var_div_arithm", |meta| { + let s = meta.query_selector(selector); + let mut constraints = vec![]; + + let b = meta.query_advice(columns[columns.len() - 1], Rotation::cur()); + for i in 0..(columns.len() - 1) / Self::num_cols_per_op() { + let offset = i * Self::num_cols_per_op(); + // Constrain that (2 * a + b) = (2 * b) * c + r + let a = meta.query_advice(columns[offset], Rotation::cur()); + let c = meta.query_advice(columns[offset + 1], Rotation::cur()); + let r = meta.query_advice(columns[offset + 2], Rotation::cur()); + + let lhs = a.clone() * two.clone() + b.clone(); + let rhs = b.clone() * two.clone() * c + r.clone(); + constraints.push(s.clone() * (lhs - rhs)); + + // Constrain that (2 * b - r) = br1 * max_val + br0 + let br1 = meta.query_advice(columns[offset + 3], Rotation::cur()); + let br0 = meta.query_advice(columns[offset + 4], Rotation::cur()); + let lhs = b.clone() * two.clone() - r.clone(); + let rhs = br1 * range.clone() + br0; + constraints.push(s.clone() * (lhs - rhs)); + + // Constrains that r = r1 * max_val + r0 + let r1 = meta.query_advice(columns[offset + 5], Rotation::cur()); + let r0 = meta.query_advice(columns[offset + 6], Rotation::cur()); + let lhs = r.clone(); + let rhs = r1 * range.clone() + r0; + constraints.push(s.clone() * (lhs - rhs)); + } + + constraints + }); + + // For var div big, we assume that a, b > 0 and are outputs of the previous layer + // r must be constrained to be in [0, b) + for i in 0..(columns.len() - 1) / Self::num_cols_per_op() { + let offset = i * Self::num_cols_per_op(); + + // (2 * b - r)_{1, 0} \in [0, 2^N) + meta.lookup("var div big br1", |meta| { + let s = meta.query_selector(selector); + let br1 = meta.query_advice(columns[offset + 3], Rotation::cur()); + vec![(s * br1, lookup)] + }); + meta.lookup("var div big br0", |meta| { + let s = meta.query_selector(selector); + let br0 = meta.query_advice(columns[offset + 4], Rotation::cur()); + vec![(s * br0, lookup)] + }); + 
// r_{1, 0} \in [0, 2^N) + meta.lookup("var div big r1", |meta| { + let s = meta.query_selector(selector); + let r1 = meta.query_advice(columns[offset + 5], Rotation::cur()); + vec![(s * r1, lookup)] + }); + meta.lookup("var div big r0", |meta| { + let s = meta.query_selector(selector); + let r0 = meta.query_advice(columns[offset + 6], Rotation::cur()); + vec![(s * r0, lookup)] + }); + } + + let mut selectors = gadget_config.selectors; + selectors.insert(GadgetType::VarDivRoundBig, vec![selector]); + + GadgetConfig { + columns, + tables, + selectors, + ..gadget_config + } + } +} + +impl Gadget for VarDivRoundBigChip { + fn name(&self) -> String { + "VarDivBigRoundChip".to_string() + } + + fn num_cols_per_op(&self) -> usize { + Self::num_cols_per_op() + } + + fn num_inputs_per_row(&self) -> usize { + (self.config.columns.len() - 1) / self.num_cols_per_op() + } + + fn num_outputs_per_row(&self) -> usize { + self.num_inputs_per_row() + } + + fn op_row_region( + &self, + region: &mut Region, + row_offset: usize, + vec_inputs: &Vec>>, + single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + let a_vec = &vec_inputs[0]; + // let zero = single_inputs[0].clone(); + let b = &single_inputs[1]; + + let div_outp_min_val_i64 = self.config.div_outp_min_val; + let div_inp_min_val_pos_i64 = -self.config.shift_min_val; + let num_rows = self.config.num_rows as i64; + + if self.config.use_selectors { + let selector = self + .config + .selectors + .get(&GadgetType::VarDivRoundBig) + .unwrap()[0]; + selector.enable(region, row_offset)?; + } + + b.copy_advice( + || "", + region, + self.config.columns[self.config.columns.len() - 1], + row_offset, + )?; + + let mut div_out = vec![]; + for (i, a) in a_vec.iter().enumerate() { + let offset = i * self.num_cols_per_op(); + a.copy_advice(|| "", region, self.config.columns[offset], row_offset) + .unwrap(); + + let div_mod = a.value().zip(b.value()).map(|(a, b)| { + let b = convert_to_u128(b); + // Needs to be divisible by b + let 
div_inp_min_val_pos_i64 = div_inp_min_val_pos_i64 / (b as i64) * (b as i64); + let div_inp_min_val_pos = F::from(div_inp_min_val_pos_i64 as u64); + + let a_pos = *a + div_inp_min_val_pos; + let a = convert_to_u128(&a_pos); + // c = (2 * a + b) / (2 * b) + let c_pos = a.rounded_div(b); + let c = (c_pos as i128 - (div_inp_min_val_pos_i64 as u128 / b) as i128) as i64; + + // r = (2 * a + b) % (2 * b) + let rem_floor = (a as i128) - (c_pos * b) as i128; + let r = 2 * rem_floor + (b as i128); + let r = r as i64; + (c, r) + }); + + let br_split = div_mod.zip(b.value()).map(|((_, r), b)| { + let b = convert_to_u128(b) as i64; + let val = 2 * b - r; + let p1 = val / num_rows; + let p0 = val % num_rows; + // val = p1 * max_val + p0 + (p1, p0) + }); + + let r_split = div_mod.map(|(_, r)| { + let p1 = r / num_rows; + let p0 = r % num_rows; + // val = p1 * max_val + p0 + (p1, p0) + }); + + let div_cell = region.assign_advice( + || "", + self.config.columns[offset + 1], + row_offset, + || { + div_mod.map(|(c, _)| { + let offset = F::from(-div_outp_min_val_i64 as u64); + let c = F::from((c - div_outp_min_val_i64) as u64); + c - offset + }) + }, + )?; + let _mod_cell = region.assign_advice( + || "", + self.config.columns[offset + 2], + row_offset, + || div_mod.map(|(_, r)| F::from(r as u64)), + )?; + // Assign 2 * b - r to the next 2 columns + let _br_split_cell_1 = region.assign_advice( + || "", + self.config.columns[offset + 3], + row_offset, + || br_split.map(|(p1, _)| F::from(p1 as u64)), + )?; + let _br_split_cell_2 = region.assign_advice( + || "", + self.config.columns[offset + 4], + row_offset, + || br_split.map(|(_, p0)| F::from(p0 as u64)), + )?; + // Assign r to the next 2 columns + let _r_split_cell_1 = region.assign_advice( + || "", + self.config.columns[offset + 5], + row_offset, + || r_split.map(|(p1, _)| F::from(p1 as u64)), + )?; + let _r_split_cell_2 = region.assign_advice( + || "", + self.config.columns[offset + 6], + row_offset, + || r_split.map(|(_, p0)| 
F::from(p0 as u64)), + )?; + + div_out.push(div_cell); + } + + Ok(div_out) + } + + fn forward( + &self, + mut layouter: impl Layouter, + vec_inputs: &Vec>>, + single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + let mut inps = vec_inputs[0].clone(); + let initial_len = inps.len(); + + // Needed to pad + let default = &single_inputs[0]; + while inps.len() % self.num_inputs_per_row() != 0 { + inps.push(&default); + } + + let res = self.op_aligned_rows( + layouter.namespace(|| "var_div_big"), + &vec![inps], + single_inputs, + )?; + Ok(res[..initial_len].to_vec()) + } +} diff --git a/mnist_zkml/src/gadgets/var_div_big3.rs b/mnist_zkml/src/gadgets/var_div_big3.rs new file mode 100644 index 0000000..90c6651 --- /dev/null +++ b/mnist_zkml/src/gadgets/var_div_big3.rs @@ -0,0 +1,302 @@ +use std::{marker::PhantomData, rc::Rc}; + +use halo2_proofs::{ + circuit::{AssignedCell, Layouter, Region}, + halo2curves::ff::PrimeField, + plonk::{ConstraintSystem, Error, Expression}, + poly::Rotation, +}; +use rounded_div::RoundedDiv; + +use super::gadget::{convert_to_u128, Gadget, GadgetConfig, GadgetType}; + +pub struct VarDivRoundBig3Chip { + config: Rc, + _marker: PhantomData, +} + +impl VarDivRoundBig3Chip { + pub fn construct(config: Rc) -> Self { + Self { + config, + _marker: PhantomData, + } + } + + pub fn num_cols_per_op() -> usize { + 9 + } + + pub fn configure(meta: &mut ConstraintSystem, gadget_config: GadgetConfig) -> GadgetConfig { + let columns = gadget_config.columns; + let selector = meta.complex_selector(); + let two = Expression::Constant(F::from(2)); + let range = Expression::Constant(F::from(gadget_config.num_rows as u64)); + let range_sq = range.clone() * range.clone(); + + let tables = gadget_config.tables; + let lookup = tables.get(&GadgetType::InputLookup).unwrap()[0]; + + // a | c | r | (2b - r)_2 | (2 b - r)_1 | (2 b - r)_0 | r_2 | r_1 | r_0 | ... 
| b + // a / b = c + meta.create_gate("var_div_big3_arithm", |meta| { + let s = meta.query_selector(selector); + let mut constraints = vec![]; + + let b = meta.query_advice(columns[columns.len() - 1], Rotation::cur()); + for i in 0..(columns.len() - 1) / Self::num_cols_per_op() { + let offset = i * Self::num_cols_per_op(); + // Constrain that (2 * a + b) = (2 * b) * c + r + let a = meta.query_advice(columns[offset], Rotation::cur()); + let c = meta.query_advice(columns[offset + 1], Rotation::cur()); + let r = meta.query_advice(columns[offset + 2], Rotation::cur()); + + let lhs = a.clone() * two.clone() + b.clone(); + let rhs = b.clone() * two.clone() * c + r.clone(); + constraints.push(s.clone() * (lhs - rhs)); + + // Constrain that (2 * b - r) = br1 * max_val + br0 + let br2 = meta.query_advice(columns[offset + 3], Rotation::cur()); + let br1 = meta.query_advice(columns[offset + 4], Rotation::cur()); + let br0 = meta.query_advice(columns[offset + 5], Rotation::cur()); + let lhs = b.clone() * two.clone() - r.clone(); + let rhs = br2 * range_sq.clone() + br1 * range.clone() + br0; + constraints.push(s.clone() * (lhs - rhs)); + + // Constrains that r = r1 * max_val + r0 + let r2 = meta.query_advice(columns[offset + 6], Rotation::cur()); + let r1 = meta.query_advice(columns[offset + 7], Rotation::cur()); + let r0 = meta.query_advice(columns[offset + 8], Rotation::cur()); + let lhs = r.clone(); + let rhs = r2 * range_sq.clone() + r1 * range.clone() + r0; + constraints.push(s.clone() * (lhs - rhs)); + } + + constraints + }); + + // For var div big, we assume that a, b > 0 and are outputs of the previous layer + // r must be constrained to be in [0, b) + for i in 0..(columns.len() - 1) / Self::num_cols_per_op() { + let offset = i * Self::num_cols_per_op(); + + // (2 * b - r)_{1, 0} \in [0, 2^N) + meta.lookup("var div big br2", |meta| { + let s = meta.query_selector(selector); + let br2 = meta.query_advice(columns[offset + 3], Rotation::cur()); + vec![(s * br2, lookup)] + 
}); + meta.lookup("var div big br1", |meta| { + let s = meta.query_selector(selector); + let br1 = meta.query_advice(columns[offset + 4], Rotation::cur()); + vec![(s * br1, lookup)] + }); + meta.lookup("var div big br0", |meta| { + let s = meta.query_selector(selector); + let br0 = meta.query_advice(columns[offset + 5], Rotation::cur()); + vec![(s * br0, lookup)] + }); + // r_{1, 0} \in [0, 2^N) + meta.lookup("var div big r2", |meta| { + let s = meta.query_selector(selector); + let r2 = meta.query_advice(columns[offset + 6], Rotation::cur()); + vec![(s * r2, lookup)] + }); + meta.lookup("var div big r1", |meta| { + let s = meta.query_selector(selector); + let r1 = meta.query_advice(columns[offset + 7], Rotation::cur()); + vec![(s * r1, lookup)] + }); + meta.lookup("var div big r0", |meta| { + let s = meta.query_selector(selector); + let r0 = meta.query_advice(columns[offset + 8], Rotation::cur()); + vec![(s * r0, lookup)] + }); + } + + let mut selectors = gadget_config.selectors; + selectors.insert(GadgetType::VarDivRoundBig3, vec![selector]); + + GadgetConfig { + columns, + tables, + selectors, + ..gadget_config + } + } +} + +impl Gadget for VarDivRoundBig3Chip { + fn name(&self) -> String { + "VarDivBig3RoundChip".to_string() + } + + fn num_cols_per_op(&self) -> usize { + Self::num_cols_per_op() + } + + fn num_inputs_per_row(&self) -> usize { + (self.config.columns.len() - 1) / self.num_cols_per_op() + } + + fn num_outputs_per_row(&self) -> usize { + self.num_inputs_per_row() + } + + fn op_row_region( + &self, + region: &mut Region, + row_offset: usize, + vec_inputs: &Vec>>, + single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + let a_vec = &vec_inputs[0]; + // let zero = single_inputs[0].clone(); + let b = &single_inputs[1]; + + let c_shift_base = (-(1_i64 << 62)) as i128; + let num_rows = self.config.num_rows as i128; + + if self.config.use_selectors { + let selector = self + .config + .selectors + .get(&GadgetType::VarDivRoundBig3) + .unwrap()[0]; + 
selector.enable(region, row_offset)?; + } + + b.copy_advice( + || "", + region, + self.config.columns[self.config.columns.len() - 1], + row_offset, + )?; + + let mut div_out = vec![]; + for (i, a) in a_vec.iter().enumerate() { + let offset = i * self.num_cols_per_op(); + a.copy_advice(|| "", region, self.config.columns[offset], row_offset) + .unwrap(); + + let div_mod = a.value().zip(b.value()).map(|(a, b)| { + let b = convert_to_u128(b); + let c_shift = (-c_shift_base) as u128 / b * b; + let div_inp_min_val_pos = F::from(c_shift as u64); + + let a_pos = *a + div_inp_min_val_pos; + let a = convert_to_u128(&a_pos); + // c = (2 * a + b) / (2 * b) + let c_pos = a.rounded_div(b); + let c = c_pos as i128 - (c_shift / b) as i128; + + // r = (2 * a + b) % (2 * b) + let rem_floor = (a as i128) - (c_pos * b) as i128; + let r = 2 * rem_floor + (b as i128); + (c, r) + }); + + let br_split = div_mod.zip(b.value()).map(|((_, r), b)| { + let b = convert_to_u128(b) as i128; + let val = 2 * b - r; + let p2 = val / (num_rows * num_rows); + let p1 = (val % (num_rows * num_rows)) / num_rows; + let p0 = val % num_rows; + // val = p2 * max_val^2 + p1 * max_val + p0 + (p2, p1, p0) + }); + + let r_split = div_mod.map(|(_, r)| { + let p2 = r / (num_rows * num_rows); + let p1 = (r % (num_rows * num_rows)) / num_rows; + let p0 = r % num_rows; + // val = p1 * max_val + p0 + (p2, p1, p0) + }); + + let div_cell = region.assign_advice( + || "", + self.config.columns[offset + 1], + row_offset, + || { + div_mod.map(|(c, _)| { + let offset = F::from(-c_shift_base as u64); + let c = F::from((c - c_shift_base) as u64); + c - offset + }) + }, + )?; + let _mod_cell = region.assign_advice( + || "", + self.config.columns[offset + 2], + row_offset, + || div_mod.map(|(_, r)| F::from(r as u64)), + )?; + // Assign 2 * b - r to the next 3 columns + let _br_split_cell_2 = region.assign_advice( + || "", + self.config.columns[offset + 3], + row_offset, + || br_split.map(|(p2, _, _)| F::from(p2 as u64)), + )?; + 
let _br_split_cell_1 = region.assign_advice( + || "", + self.config.columns[offset + 4], + row_offset, + || br_split.map(|(_, p1, _)| F::from(p1 as u64)), + )?; + let _br_split_cell_0 = region.assign_advice( + || "", + self.config.columns[offset + 5], + row_offset, + || br_split.map(|(_, _, p0)| F::from(p0 as u64)), + )?; + // Assign r to the next 3 columns + let _r_split_cell_2 = region.assign_advice( + || "", + self.config.columns[offset + 6], + row_offset, + || r_split.map(|(p2, _, _)| F::from(p2 as u64)), + )?; + let _r_split_cell_1 = region.assign_advice( + || "", + self.config.columns[offset + 7], + row_offset, + || r_split.map(|(_, p1, _)| F::from(p1 as u64)), + )?; + let _r_split_cell_0 = region.assign_advice( + || "", + self.config.columns[offset + 8], + row_offset, + || r_split.map(|(_, _, p0)| F::from(p0 as u64)), + )?; + + div_out.push(div_cell); + } + + Ok(div_out) + } + + fn forward( + &self, + mut layouter: impl Layouter, + vec_inputs: &Vec>>, + single_inputs: &Vec<&AssignedCell>, + ) -> Result>, Error> { + let mut inps = vec_inputs[0].clone(); + let initial_len = inps.len(); + + // Needed to pad + let default = &single_inputs[0]; + while inps.len() % self.num_inputs_per_row() != 0 { + inps.push(&default); + } + + let res = self.op_aligned_rows( + layouter.namespace(|| "var_div_big3"), + &vec![inps], + single_inputs, + )?; + Ok(res[..initial_len].to_vec()) + } +} diff --git a/mnist_zkml/src/layers.rs b/mnist_zkml/src/layers.rs new file mode 100644 index 0000000..49c3eff --- /dev/null +++ b/mnist_zkml/src/layers.rs @@ -0,0 +1,30 @@ +// Generics +pub mod averager; + +pub mod arithmetic; +pub mod shape; + +// Concrete implementations +pub mod avg_pool_2d; +pub mod batch_mat_mul; +pub mod conv2d; +pub mod div_fixed; +pub mod fully_connected; +pub mod logistic; +pub mod max_pool_2d; +pub mod mean; +pub mod noop; +pub mod pow; +pub mod rsqrt; +pub mod softmax; +pub mod sqrt; +pub mod square; +pub mod squared_diff; +pub mod tanh; +pub mod update; + +// 
Special: dag +pub mod dag; + +// Special: layer +pub mod layer; diff --git a/mnist_zkml/src/layers/arithmetic.rs b/mnist_zkml/src/layers/arithmetic.rs new file mode 100644 index 0000000..61073af --- /dev/null +++ b/mnist_zkml/src/layers/arithmetic.rs @@ -0,0 +1,55 @@ +use std::{collections::HashMap, rc::Rc}; + +use halo2_proofs::{ + circuit::{AssignedCell, Layouter}, + halo2curves::ff::PrimeField, + plonk::Error, +}; + +use crate::{gadgets::gadget::GadgetConfig, utils::helpers::broadcast}; + +use super::layer::{AssignedTensor, CellRc}; + +pub mod add; +pub mod div_var; +pub mod mul; +pub mod sub; + +pub trait Arithmetic { + fn gadget_forward( + &self, + layouter: impl Layouter, + vec_inputs: &Vec>>, + constants: &Vec<&AssignedCell>, + gadget_config: Rc, + ) -> Result>, Error>; + + fn arithmetic_forward( + &self, + mut layouter: impl Layouter, + tensors: &Vec>, + constants: &HashMap>, + gadget_config: Rc, + ) -> Result<(Vec>, Vec), Error> { + assert_eq!(tensors.len(), 2); + // println!("tensors: {:?} {:?}", tensors[0].shape(), tensors[1].shape()); + let (inp1, inp2) = broadcast(&tensors[0], &tensors[1]); + let out_shape = inp1.shape().clone(); + assert_eq!(inp1.shape(), inp2.shape()); + + let zero = constants.get(&0).unwrap().as_ref(); + + let inp1_vec = inp1.iter().map(|x| x.as_ref()).collect::>(); + let inp2_vec = inp2.iter().map(|x| x.as_ref()).collect::>(); + let vec_inputs = vec![inp1_vec, inp2_vec]; + let constants = vec![zero]; + let out = self.gadget_forward( + layouter.namespace(|| ""), + &vec_inputs, + &constants, + gadget_config.clone(), + )?; + let out = out.into_iter().map(|x| Rc::new(x)).collect::>(); + Ok((out, out_shape.to_vec())) + } +} diff --git a/mnist_zkml/src/layers/arithmetic/add.rs b/mnist_zkml/src/layers/arithmetic/add.rs new file mode 100644 index 0000000..9460b23 --- /dev/null +++ b/mnist_zkml/src/layers/arithmetic/add.rs @@ -0,0 +1,106 @@ +use std::{collections::HashMap, rc::Rc, vec}; + +use halo2_proofs::{ + circuit::{AssignedCell, 
Layouter}, + halo2curves::ff::PrimeField, + plonk::Error, +}; +use ndarray::{Array, IxDyn}; + +use crate::{ + gadgets::{ + add_pairs::AddPairsChip, + gadget::{Gadget, GadgetConfig, GadgetType}, + nonlinear::relu::ReluChip, + }, + layers::layer::{ActivationType, AssignedTensor, CellRc, GadgetConsumer}, +}; + +use super::{ + super::layer::{Layer, LayerConfig}, + Arithmetic, +}; + +#[derive(Clone, Debug)] +pub struct AddChip {} + +impl AddChip { + fn get_activation(&self, layer_params: &Vec) -> ActivationType { + let activation = layer_params[0]; + match activation { + 0 => ActivationType::None, + 1 => ActivationType::Relu, + _ => panic!("Unsupported activation type for add"), + } + } +} + +impl Arithmetic for AddChip { + fn gadget_forward( + &self, + mut layouter: impl Layouter, + vec_inputs: &Vec>>, + constants: &Vec<&AssignedCell>, + gadget_config: Rc, + ) -> Result>, Error> { + let add_pairs_chip = AddPairsChip::::construct(gadget_config); + let out = add_pairs_chip.forward(layouter.namespace(|| "add chip"), &vec_inputs, constants)?; + Ok(out) + } +} + +impl Layer for AddChip { + fn forward( + &self, + mut layouter: impl Layouter, + tensors: &Vec>, + constants: &HashMap>, + gadget_config: Rc, + layer_config: &LayerConfig, + ) -> Result>, Error> { + let activation = self.get_activation(&layer_config.layer_params); + + // Do the addition + let (out, out_shape) = self.arithmetic_forward( + layouter.namespace(|| ""), + tensors, + constants, + gadget_config.clone(), + )?; + + // Do the fused activation + let out = if activation == ActivationType::Relu { + let zero = constants.get(&0).unwrap(); + let single_inps = vec![zero.as_ref()]; + + let out = out.iter().map(|x| x.as_ref()).collect::>(); + + let relu_chip = ReluChip::::construct(gadget_config); + let out = relu_chip.forward(layouter.namespace(|| "relu"), &vec![out], &single_inps)?; + let out = out.into_iter().map(|x| Rc::new(x)).collect::>(); + out + } else if activation == ActivationType::None { + out + } else { + 
panic!("Unsupported activation type for add"); + }; + + let out = Array::from_shape_vec(IxDyn(out_shape.as_slice()), out).unwrap(); + + Ok(vec![out]) + } +} + +impl GadgetConsumer for AddChip { + fn used_gadgets(&self, layer_params: Vec) -> Vec { + let activation = self.get_activation(&layer_params); + let mut outp = vec![GadgetType::AddPairs]; + + match activation { + ActivationType::Relu => outp.push(GadgetType::Relu), + ActivationType::None => (), + _ => panic!("Unsupported activation type for add"), + } + outp + } +} diff --git a/mnist_zkml/src/layers/arithmetic/div_var.rs b/mnist_zkml/src/layers/arithmetic/div_var.rs new file mode 100644 index 0000000..1b38643 --- /dev/null +++ b/mnist_zkml/src/layers/arithmetic/div_var.rs @@ -0,0 +1,93 @@ +use std::{collections::HashMap, rc::Rc, vec}; + +use halo2_proofs::{ + circuit::{AssignedCell, Layouter}, + halo2curves::ff::PrimeField, + plonk::Error, +}; +use ndarray::{Array, IxDyn}; + +use crate::{ + gadgets::{ + gadget::{Gadget, GadgetConfig, GadgetType}, + mul_pairs::MulPairsChip, + var_div::VarDivRoundChip, + }, + layers::layer::{AssignedTensor, CellRc, GadgetConsumer, Layer}, +}; + +use super::Arithmetic; + +pub struct DivVarChip {} + +// TODO: hack. 
Used for multiplying by the scale factor +impl Arithmetic for DivVarChip { + fn gadget_forward( + &self, + mut layouter: impl Layouter, + vec_inputs: &Vec>>, + constants: &Vec<&AssignedCell>, + gadget_config: Rc, + ) -> Result>, Error> { + let mul_pairs_chip = MulPairsChip::::construct(gadget_config.clone()); + + let out = mul_pairs_chip.forward( + layouter.namespace(|| "mul pairs chip"), + &vec_inputs, + constants, + )?; + Ok(out) + } +} + +impl Layer for DivVarChip { + fn forward( + &self, + mut layouter: impl Layouter, + tensors: &Vec>, + constants: &HashMap>, + gadget_config: Rc, + _layer_config: &crate::layers::layer::LayerConfig, + ) -> Result>, Error> { + assert_eq!(tensors.len(), 2); + // TODO: We only support dividing by a single number for now + assert_eq!(tensors[1].shape().len(), 1); + assert_eq!(tensors[1].shape()[0], 1); + + let sf = constants + .get(&(gadget_config.scale_factor as i64)) + .unwrap() + .as_ref(); + + let sf_tensor = Array::from_shape_vec(IxDyn(&[1]), vec![Rc::new(sf.clone())]).unwrap(); + + // out = inp * SF + let (out, out_shape) = self.arithmetic_forward( + layouter.namespace(|| ""), + &vec![tensors[0].clone(), sf_tensor], + constants, + gadget_config.clone(), + )?; + + let var_div_chip = VarDivRoundChip::::construct(gadget_config.clone()); + let div = tensors[1].iter().next().unwrap().as_ref(); + let zero = constants.get(&0).unwrap().as_ref(); + let single_inputs = vec![zero, div]; + let out = out.iter().map(|x| x.as_ref()).collect::>(); + let out = var_div_chip.forward(layouter.namespace(|| "mul div"), &vec![out], &single_inputs)?; + + let out = out.into_iter().map(|x| Rc::new(x)).collect::>(); + let out = Array::from_shape_vec(IxDyn(out_shape.as_slice()), out).unwrap(); + Ok(vec![out]) + } +} + +impl GadgetConsumer for DivVarChip { + fn used_gadgets(&self, _layer_params: Vec) -> Vec { + vec![ + GadgetType::MulPairs, + GadgetType::VarDivRound, + GadgetType::InputLookup, + ] + } +} diff --git 
a/mnist_zkml/src/layers/arithmetic/mul.rs b/mnist_zkml/src/layers/arithmetic/mul.rs new file mode 100644 index 0000000..e23e6b4 --- /dev/null +++ b/mnist_zkml/src/layers/arithmetic/mul.rs @@ -0,0 +1,87 @@ +use std::{collections::HashMap, rc::Rc, vec}; + +use halo2_proofs::{ + circuit::{AssignedCell, Layouter}, + halo2curves::ff::PrimeField, + plonk::Error, +}; +use ndarray::{Array, IxDyn}; + +use crate::{ + gadgets::{ + gadget::{Gadget, GadgetConfig, GadgetType}, + mul_pairs::MulPairsChip, + var_div::VarDivRoundChip, + }, + layers::layer::{AssignedTensor, CellRc, GadgetConsumer}, +}; + +use super::{ + super::layer::{Layer, LayerConfig}, + Arithmetic, +}; + +#[derive(Clone, Debug)] +pub struct MulChip {} + +impl Arithmetic for MulChip { + fn gadget_forward( + &self, + mut layouter: impl Layouter, + vec_inputs: &Vec>>, + constants: &Vec<&AssignedCell>, + gadget_config: Rc, + ) -> Result>, Error> { + let mul_pairs_chip = MulPairsChip::::construct(gadget_config.clone()); + + let out = mul_pairs_chip.forward( + layouter.namespace(|| "mul pairs chip"), + &vec_inputs, + constants, + )?; + Ok(out) + } +} + +// FIXME: move this + add to an arithmetic layer +impl Layer for MulChip { + fn forward( + &self, + mut layouter: impl Layouter, + tensors: &Vec>, + constants: &HashMap>, + gadget_config: Rc, + _layer_config: &LayerConfig, + ) -> Result>, Error> { + let (out, out_shape) = self.arithmetic_forward( + layouter.namespace(|| ""), + tensors, + constants, + gadget_config.clone(), + )?; + + let var_div_chip = VarDivRoundChip::::construct(gadget_config.clone()); + let div = constants + .get(&(gadget_config.scale_factor as i64)) + .unwrap() + .as_ref(); + let zero = constants.get(&0).unwrap().as_ref(); + let single_inputs = vec![zero, div]; + let out = out.iter().map(|x| x.as_ref()).collect::>(); + let out = var_div_chip.forward(layouter.namespace(|| "mul div"), &vec![out], &single_inputs)?; + + let out = out.into_iter().map(|x| Rc::new(x)).collect::>(); + let out = 
Array::from_shape_vec(IxDyn(out_shape.as_slice()), out).unwrap(); + Ok(vec![out]) + } +} + +impl GadgetConsumer for MulChip { + fn used_gadgets(&self, _layer_params: Vec) -> Vec { + vec![ + GadgetType::MulPairs, + GadgetType::VarDivRound, + GadgetType::InputLookup, + ] + } +} diff --git a/mnist_zkml/src/layers/arithmetic/sub.rs b/mnist_zkml/src/layers/arithmetic/sub.rs new file mode 100644 index 0000000..4299039 --- /dev/null +++ b/mnist_zkml/src/layers/arithmetic/sub.rs @@ -0,0 +1,65 @@ +use std::{collections::HashMap, rc::Rc, vec}; + +use halo2_proofs::{ + circuit::{AssignedCell, Layouter}, + halo2curves::ff::PrimeField, + plonk::Error, +}; +use ndarray::{Array, IxDyn}; + +use crate::{ + gadgets::{ + gadget::{Gadget, GadgetConfig, GadgetType}, + sub_pairs::SubPairsChip, + }, + layers::layer::{AssignedTensor, CellRc, GadgetConsumer}, +}; + +use super::{ + super::layer::{Layer, LayerConfig}, + Arithmetic, +}; + +#[derive(Clone, Debug)] +pub struct SubChip {} + +impl Arithmetic for SubChip { + fn gadget_forward( + &self, + mut layouter: impl Layouter, + vec_inputs: &Vec>>, + constants: &Vec<&AssignedCell>, + gadget_config: Rc, + ) -> Result>, Error> { + let sub_pairs_chip = SubPairsChip::::construct(gadget_config); + let out = sub_pairs_chip.forward(layouter.namespace(|| "sub chip"), &vec_inputs, constants)?; + Ok(out) + } +} + +impl Layer for SubChip { + fn forward( + &self, + mut layouter: impl Layouter, + tensors: &Vec>, + constants: &HashMap>, + gadget_config: Rc, + _layer_config: &LayerConfig, + ) -> Result>, Error> { + let (out, out_shape) = self.arithmetic_forward( + layouter.namespace(|| ""), + tensors, + constants, + gadget_config.clone(), + )?; + let out = Array::from_shape_vec(IxDyn(out_shape.as_slice()), out).unwrap(); + + Ok(vec![out]) + } +} + +impl GadgetConsumer for SubChip { + fn used_gadgets(&self, _layer_params: Vec) -> Vec { + vec![GadgetType::SubPairs] + } +} diff --git a/mnist_zkml/src/layers/averager.rs b/mnist_zkml/src/layers/averager.rs new 
file mode 100644 index 0000000..3445fb4 --- /dev/null +++ b/mnist_zkml/src/layers/averager.rs @@ -0,0 +1,72 @@ +use std::{collections::HashMap, rc::Rc}; + +use halo2_proofs::{ + circuit::{AssignedCell, Layouter}, + halo2curves::ff::PrimeField, + plonk::Error, +}; + +use crate::gadgets::gadget::Gadget; +use crate::gadgets::{adder::AdderChip, gadget::GadgetConfig, var_div::VarDivRoundChip}; + +use super::layer::{AssignedTensor, CellRc, LayerConfig}; + +pub trait Averager { + fn splat(&self, input: &AssignedTensor, layer_config: &LayerConfig) -> Vec>>; + + fn get_div_val( + &self, + layouter: impl Layouter, + tensors: &Vec>, + gadget_config: Rc, + layer_config: &LayerConfig, + ) -> Result, Error>; + + fn avg_forward( + &self, + mut layouter: impl Layouter, + tensors: &Vec>, + constants: &HashMap>, + gadget_config: Rc, + layer_config: &LayerConfig, + ) -> Result>, Error> { + // Due to Mean BS + // assert_eq!(tensors.len(), 1); + let zero = constants.get(&0).unwrap().as_ref(); + + let inp = &tensors[0]; + let splat_inp = self.splat(inp, layer_config); + + let adder_chip = AdderChip::::construct(gadget_config.clone()); + let single_inputs = vec![zero]; + let mut added = vec![]; + for i in 0..splat_inp.len() { + let tmp = splat_inp[i].iter().map(|x| x.as_ref()).collect::>(); + let tmp = adder_chip.forward( + layouter.namespace(|| format!("average {}", i)), + &vec![tmp], + &single_inputs, + )?; + added.push(tmp[0].clone()); + } + + let div = self.get_div_val( + layouter.namespace(|| "average div"), + tensors, + gadget_config.clone(), + layer_config, + )?; + let var_div_chip = VarDivRoundChip::::construct(gadget_config.clone()); + + let single_inputs = vec![zero, &div]; + let added = added.iter().map(|x| x).collect::>(); + let dived = var_div_chip.forward( + layouter.namespace(|| "average div"), + &vec![added], + &single_inputs, + )?; + let dived = dived.into_iter().map(|x| Rc::new(x)).collect::>(); + + Ok(dived) + } +} diff --git a/mnist_zkml/src/layers/avg_pool_2d.rs 
b/mnist_zkml/src/layers/avg_pool_2d.rs new file mode 100644 index 0000000..05c5178 --- /dev/null +++ b/mnist_zkml/src/layers/avg_pool_2d.rs @@ -0,0 +1,95 @@ +use std::{collections::HashMap, rc::Rc}; + +use halo2_proofs::{ + circuit::{AssignedCell, Layouter, Value}, + halo2curves::ff::PrimeField, + plonk::Error, +}; +use ndarray::{Array, IxDyn}; + +use crate::{ + gadgets::gadget::{GadgetConfig, GadgetType}, + layers::max_pool_2d::MaxPool2DChip, +}; + +use super::{ + averager::Averager, + layer::{AssignedTensor, CellRc, GadgetConsumer, Layer, LayerConfig}, +}; + +pub struct AvgPool2DChip {} + +impl Averager for AvgPool2DChip { + fn splat(&self, input: &AssignedTensor, layer_config: &LayerConfig) -> Vec>> { + assert_eq!(input.shape().len(), 4); + // Don't support batch size > 1 yet + assert_eq!(input.shape()[0], 1); + + // TODO: refactor this + MaxPool2DChip::splat(input, layer_config).unwrap() + } + + fn get_div_val( + &self, + mut layouter: impl Layouter, + _tensors: &Vec>, + gadget_config: Rc, + layer_config: &LayerConfig, + ) -> Result, Error> { + // FIXME: this needs to be revealed + let div = layer_config.layer_params[0] * layer_config.layer_params[1]; + let div = F::from(div as u64); + let div = layouter + .assign_region( + || "avg pool 2d div", + |mut region| { + let div = region + .assign_advice( + || "avg pool 2d div", + gadget_config.columns[0], + 0, + || Value::known(div), + ) + .unwrap(); + Ok(div) + }, + ) + .unwrap(); + + Ok(div) + } +} + +impl Layer for AvgPool2DChip { + fn forward( + &self, + layouter: impl Layouter, + tensors: &Vec>, + constants: &HashMap>, + gadget_config: Rc, + layer_config: &LayerConfig, + ) -> Result>, Error> { + let dived = self + .avg_forward(layouter, tensors, constants, gadget_config, layer_config) + .unwrap(); + + let inp = &tensors[0]; + // TODO: refactor this + let out_xy = MaxPool2DChip::shape(inp, layer_config); + let out_shape = vec![1, out_xy.0, out_xy.1, inp.shape()[3]]; + println!("out_shape: {:?}", out_shape); + + 
let out = Array::from_shape_vec(IxDyn(&out_shape), dived).unwrap(); + Ok(vec![out]) + } +} + +impl GadgetConsumer for AvgPool2DChip { + fn used_gadgets(&self, _layer_params: Vec) -> Vec { + vec![ + GadgetType::Adder, + GadgetType::VarDivRound, + GadgetType::InputLookup, + ] + } +} diff --git a/mnist_zkml/src/layers/batch_mat_mul.rs b/mnist_zkml/src/layers/batch_mat_mul.rs new file mode 100644 index 0000000..94ec288 --- /dev/null +++ b/mnist_zkml/src/layers/batch_mat_mul.rs @@ -0,0 +1,95 @@ +use std::{collections::HashMap, marker::PhantomData, rc::Rc}; + +use halo2_proofs::{circuit::Layouter, halo2curves::ff::PrimeField, plonk::Error}; +use ndarray::{Array, Axis, IxDyn}; + +use crate::{ + gadgets::gadget::{GadgetConfig, GadgetType}, + layers::fully_connected::FullyConnectedConfig, +}; + +use super::{ + fully_connected::FullyConnectedChip, + layer::{AssignedTensor, CellRc, GadgetConsumer, Layer, LayerConfig}, +}; + +pub struct BatchMatMulChip {} + +impl Layer for BatchMatMulChip { + fn forward( + &self, + mut layouter: impl Layouter, + tensors: &Vec>, + constants: &HashMap>, + gadget_config: Rc, + layer_config: &LayerConfig, + ) -> Result>, Error> { + let inp1 = &tensors[0]; + let inp2 = &tensors[1]; + println!("inp1: {:?}", inp1.shape()); + println!("inp2: {:?}", inp2.shape()); + + assert_eq!(inp1.ndim(), 3); + assert_eq!(inp2.ndim(), 3); + assert_eq!(inp1.shape()[0], inp2.shape()[0]); + + let adj_y = layer_config.layer_params[1] == 1; + if adj_y { + assert_eq!(inp1.shape()[2], inp2.shape()[2]); + } else { + assert_eq!(inp1.shape()[2], inp2.shape()[1]); + } + + let out_shape = if adj_y { + vec![inp1.shape()[0], inp1.shape()[1], inp2.shape()[1]] + } else { + vec![inp1.shape()[0], inp1.shape()[1], inp2.shape()[2]] + }; + + let fc_chip = FullyConnectedChip:: { + _marker: PhantomData, + config: FullyConnectedConfig::construct(true), + }; + + let mut outp: Vec> = vec![]; + for i in 0..inp1.shape()[0] { + let inp1_slice = inp1.index_axis(Axis(0), i).to_owned(); + // Due 
to tensorflow BS, transpose the "weights" + let inp2_slice = if adj_y { + inp2.index_axis(Axis(0), i).to_owned() + } else { + inp2.index_axis(Axis(0), i).t().to_owned() + }; + println!("inp1_slice: {:?}", inp1_slice.shape()); + println!("inp2_slice: {:?}", inp2_slice.shape()); + // Batch MM doesn't have a fused activation, so insert it here + // TODO: consider putting this in the converter? + let tmp_config = LayerConfig { + layer_params: vec![0], + ..layer_config.clone() + }; + let outp_slice = fc_chip.forward( + layouter.namespace(|| ""), + &vec![inp1_slice, inp2_slice], + constants, + gadget_config.clone(), + &tmp_config, + )?; + outp.extend(outp_slice[0].iter().map(|x| x.clone()).collect::>()); + } + + let outp = Array::from_shape_vec(IxDyn(out_shape.as_slice()), outp).unwrap(); + Ok(vec![outp]) + } +} + +impl GadgetConsumer for BatchMatMulChip { + fn used_gadgets(&self, _layer_params: Vec) -> Vec { + vec![ + GadgetType::Adder, + GadgetType::DotProduct, + GadgetType::VarDivRound, + GadgetType::InputLookup, + ] + } +} diff --git a/mnist_zkml/src/layers/conv2d.rs b/mnist_zkml/src/layers/conv2d.rs new file mode 100644 index 0000000..c8f7286 --- /dev/null +++ b/mnist_zkml/src/layers/conv2d.rs @@ -0,0 +1,452 @@ +// TODO: Speed up Depthwise operations with Freivald's algorithm + +use std::{collections::HashMap, marker::PhantomData, rc::Rc}; + +use halo2_proofs::{ + circuit::{AssignedCell, Layouter}, + halo2curves::ff::PrimeField, + plonk::Error, +}; +use ndarray::{Array, IxDyn}; + +use crate::{ + gadgets::{ + bias_div_round_relu6::BiasDivRoundRelu6Chip, + dot_prod::DotProductChip, + gadget::{Gadget, GadgetConfig, GadgetType}, + nonlinear::relu::ReluChip, + }, + layers::{ + fully_connected::{FullyConnectedChip, FullyConnectedConfig}, + shape::pad::pad, + }, +}; + +use super::layer::{ActivationType, AssignedTensor, GadgetConsumer, Layer, LayerConfig}; + +#[derive(Default, Clone, Copy, Eq, PartialEq)] +pub enum PaddingEnum { + #[default] + Same, + Valid, +} + 
+#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)] +pub enum ConvLayerEnum { + #[default] + Conv2D, + DepthwiseConv2D, +} + +pub struct Conv2DConfig { + pub conv_type: ConvLayerEnum, + pub padding: PaddingEnum, + pub activation: ActivationType, + pub stride: (usize, usize), +} + +pub struct Conv2DChip { + pub config: LayerConfig, + pub _marker: PhantomData, +} + +impl Conv2DChip { + // TODO: this is horrible. What's the best way to fix this? + pub fn param_vec_to_config(layer_params: Vec) -> Conv2DConfig { + let conv_type = match layer_params[0] { + 0 => ConvLayerEnum::Conv2D, + 1 => ConvLayerEnum::DepthwiseConv2D, + _ => panic!("Invalid conv type"), + }; + let padding = match layer_params[1] { + 0 => PaddingEnum::Same, + 1 => PaddingEnum::Valid, + _ => panic!("Invalid padding"), + }; + let activation = match layer_params[2] { + 0 => ActivationType::None, + 1 => ActivationType::Relu, + 3 => ActivationType::Relu6, + _ => panic!("Invalid activation type"), + }; + let stride = (layer_params[3] as usize, layer_params[4] as usize); + Conv2DConfig { + conv_type, + padding, + activation, + stride, + } + } + + pub fn get_padding( + h: usize, + w: usize, + si: usize, + sj: usize, + ci: usize, + cj: usize, + ) -> ((usize, usize), (usize, usize)) { + let ph = if h % si == 0 { + (ci as i64 - sj as i64).max(0) + } else { + (ci as i64 - (h % si) as i64).max(0) + } as usize; + let pw = if w % sj == 0 { + (cj as i64 - sj as i64).max(0) + } else { + (cj as i64 - (w % sj) as i64).max(0) + } as usize; + ((ph / 2, ph - ph / 2), (pw / 2, pw - pw / 2)) + } + + pub fn out_hw( + h: usize, + w: usize, + si: usize, + sj: usize, + ch: usize, + cw: usize, + padding: PaddingEnum, + ) -> (usize, usize) { + /* + println!( + "H: {}, W: {}, SI: {}, SJ: {}, CH: {}, CW: {}", + h, w, si, sj, ch, cw + ); + */ + // https://iq.opengenus.org/same-and-valid-padding/ + match padding { + PaddingEnum::Same => ((h + si - 1) / si, (w + sj - 1) / sj), + // TODO: the above is probably correct, but we always 
have valid paddings + // PaddingEnum::Same => (h / si, w / sj), + PaddingEnum::Valid => ((h - ch) / si + 1, (w - cw) / sj + 1), + } + } + + pub fn splat( + &self, + tensors: &Vec, IxDyn>>, + zero: Rc, + ) -> (Vec>>, Vec>>, Vec>) { + // assert_eq!(tensors.len(), 3); + assert!(tensors.len() <= 3); + + let conv_config = &Self::param_vec_to_config(self.config.layer_params.clone()); + + let inp = &tensors[0]; + let weights = &tensors[1]; + let zero_arr = Array::from_elem(IxDyn(&vec![1]), zero.clone()); + let biases = if tensors.len() == 3 { + &tensors[2] + } else { + &zero_arr + }; + + let h: usize = inp.shape()[1]; + let w: usize = inp.shape()[2]; + + let ch: usize = weights.shape()[1]; + let cw: usize = weights.shape()[2]; + + let (si, sj) = conv_config.stride; + + // B, H, W, C + assert_eq!(inp.shape().len(), 4); + + let (ph, pw) = if conv_config.padding == PaddingEnum::Same { + Self::get_padding(h, w, si, sj, ch, cw) + } else { + ((0, 0), (0, 0)) + }; + // println!("Padding: {:?}", (ph, pw)); + let padding = vec![[0, 0], [ph.0, ph.1], [pw.0, pw.1], [0, 0]]; + + let inp_pad = pad(&inp, padding, &zero); + + let (oh, ow) = Self::out_hw(h, w, si, sj, ch, cw, conv_config.padding); + + let mut inp_cells = vec![]; + let mut weights_cells = vec![]; + let mut biases_cells = vec![]; + let mut input_row_idx = 0; + let mut weight_row_idx = 0; + + // (output_channels x inp_channels * C_H * C_W) + for chan_out in 0..weights.shape()[0] { + weights_cells.push(vec![]); + for ci in 0..weights.shape()[1] { + for cj in 0..weights.shape()[2] { + for ck in 0..weights.shape()[3] { + weights_cells[weight_row_idx].push(weights[[chan_out, ci, cj, ck]].clone()); + } + } + } + weight_row_idx += 1; + } + + // (O_H * O_W x inp_channels * C_H * C_W) + for batch in 0..inp.shape()[0] { + for i in 0..oh { + for j in 0..ow { + inp_cells.push(vec![]); + for ci in 0..weights.shape()[1] { + for cj in 0..weights.shape()[2] { + for ck in 0..weights.shape()[3] { + let idx_i = i * si + ci; + let idx_j = j * 
sj + cj; + inp_cells[input_row_idx].push(inp_pad[[batch, idx_i, idx_j, ck]].clone()); + } + } + } + input_row_idx += 1; + } + } + } + + for _batch in 0..inp.shape()[0] { + for _ in 0..oh { + for _ in 0..ow { + for chan_out in 0..weights.shape()[0] { + if tensors.len() == 3 { + biases_cells.push(biases[chan_out].clone()); + } else { + biases_cells.push(zero.clone()); + } + } + } + } + } + + (inp_cells, weights_cells, biases_cells) + } + + pub fn splat_depthwise( + &self, + tensors: &Vec, IxDyn>>, + zero: Rc, + ) -> (Vec>>, Vec>>, Vec>) { + let input = &tensors[0]; + let weights = &tensors[1]; + let biases = &tensors[2]; + + assert_eq!(tensors.len(), 3); + assert_eq!(input.shape().len(), 4); + assert_eq!(weights.shape().len(), 4); + assert_eq!(input.shape()[0], 1); + + let conv_config = &Self::param_vec_to_config(self.config.layer_params.clone()); + let strides = conv_config.stride; + + let h: usize = input.shape()[1]; + let w: usize = input.shape()[2]; + let ch: usize = weights.shape()[1]; + let cw: usize = weights.shape()[2]; + let (si, sj) = conv_config.stride; + let (oh, ow) = Self::out_hw(h, w, si, sj, ch, cw, conv_config.padding); + + let (ph, pw) = if conv_config.padding == PaddingEnum::Same { + Self::get_padding(h, w, si, sj, ch, cw) + } else { + ((0, 0), (0, 0)) + }; + + let padding = vec![[0, 0], [ph.0, ph.1], [pw.0, pw.1], [0, 0]]; + + let inp_pad = pad(&input, padding, &zero); + + let mut inp_cells = vec![]; + let mut weight_cells = vec![]; + let mut biases_cells = vec![]; + let mut row_idx = 0; + + for i in 0..oh { + for j in 0..ow { + for chan_out in 0..weights.shape()[3] { + inp_cells.push(vec![]); + weight_cells.push(vec![]); + biases_cells.push(biases[[chan_out]].clone()); + + for ci in 0..weights.shape()[1] { + for cj in 0..weights.shape()[2] { + let idx_i = i * strides.0 + ci; + let idx_j = j * strides.1 + cj; + + inp_cells[row_idx].push(inp_pad[[0, idx_i, idx_j, chan_out]].clone()); + weight_cells[row_idx].push(weights[[0, ci, cj, 
chan_out]].clone()); + } + } + + row_idx += 1; + } + } + } + + (inp_cells, weight_cells, biases_cells) + } +} + +impl Layer for Conv2DChip { + fn forward( + &self, + mut layouter: impl Layouter, + tensors: &Vec>, + constants: &HashMap>>, + gadget_config: Rc, + layer_config: &LayerConfig, + ) -> Result>, Error> { + let conv_config = &Self::param_vec_to_config(self.config.layer_params.clone()); + let zero = constants.get(&0).unwrap(); + + let inp = &tensors[0]; + let weights = &tensors[1]; + + let (oh, ow) = Self::out_hw( + inp.shape()[1], + inp.shape()[2], + conv_config.stride.0, + conv_config.stride.1, + weights.shape()[1], + weights.shape()[2], + conv_config.padding, + ); + let batch_size = inp.shape()[0]; + + let (splat_inp, splat_weights, splat_biases) = match conv_config.conv_type { + ConvLayerEnum::Conv2D => self.splat(tensors, zero.clone()), + ConvLayerEnum::DepthwiseConv2D => self.splat_depthwise(tensors, zero.clone()), + }; + + let outp_flat: Vec> = match conv_config.conv_type { + ConvLayerEnum::Conv2D => { + let fc_chip = FullyConnectedChip:: { + _marker: PhantomData, + config: FullyConnectedConfig::construct(false), + }; + + let conv_size = splat_inp[0].len(); + let flattened_inp: Vec<_> = splat_inp.into_iter().flat_map(|x| x.into_iter()).collect(); + let flattened_weights = splat_weights + .into_iter() + .flat_map(|x| x.into_iter()) + .collect::>(); + + let out_channels = weights.shape()[0]; + let inp_array = + Array::from_shape_vec(IxDyn(&vec![batch_size * oh * ow, conv_size]), flattened_inp) + .unwrap(); + let weights_array = + Array::from_shape_vec(IxDyn(&vec![out_channels, conv_size]), flattened_weights).unwrap(); + + let outp_slice = fc_chip + .forward( + layouter.namespace(|| ""), + &vec![weights_array, inp_array], + constants, + gadget_config.clone(), + layer_config, + ) + .unwrap(); + + let outp_flat = outp_slice[0] + .t() + .into_iter() + .map(|x| (**x).clone()) + .collect::>(); + outp_flat + } + ConvLayerEnum::DepthwiseConv2D => { + // Do the 
dot products + let dot_prod_chip = DotProductChip::::construct(gadget_config.clone()); + let mut outp_flat = vec![]; + for (inp_vec, weight_vec) in splat_inp.iter().zip(splat_weights.iter()) { + let inp_vec = inp_vec.iter().map(|x| x.as_ref()).collect::>(); + let weight_vec = weight_vec.iter().map(|x| x.as_ref()).collect::>(); + let vec_inputs = vec![inp_vec, weight_vec]; + let constants = vec![zero.as_ref()]; + let outp = dot_prod_chip + .forward(layouter.namespace(|| "dot_prod"), &vec_inputs, &constants) + .unwrap(); + outp_flat.push(outp[0].clone()); + } + // println!("outp_flat: {:?}", outp_flat.len()); + + outp_flat + } + }; + + let mut biases = vec![]; + for bias in splat_biases.iter() { + biases.push(bias.as_ref()); + } + + // Compute the bias + div + relu + let bdr_chip = BiasDivRoundRelu6Chip::::construct(gadget_config.clone()); + let tmp = vec![zero.as_ref()]; + let outp_flat = outp_flat.iter().map(|x| x).collect::>(); + let outp = bdr_chip + .forward( + layouter.namespace(|| "bias_div_relu"), + &vec![outp_flat, biases], + &tmp, + ) + .unwrap(); + + // TODO: this is also horrible. The bdr chip outputs interleaved [(relu'd, div'd), (relu'd, div'd), ...] 
+ // Uninterleave depending on whether or not we're doing the relu + let outp = if conv_config.activation == ActivationType::Relu6 { + outp + .into_iter() + .step_by(2) + .map(|x| Rc::new(x)) + .collect::>() + } else if conv_config.activation == ActivationType::None { + outp + .into_iter() + .skip(1) + .step_by(2) + .map(|x| Rc::new(x)) + .collect::>() + } else if conv_config.activation == ActivationType::Relu { + let dived = outp.iter().skip(1).step_by(2).collect::>(); + let relu_chip = ReluChip::::construct(gadget_config.clone()); + let relu_outp = relu_chip + .forward(layouter.namespace(|| "relu"), &vec![dived], &tmp) + .unwrap(); + let relu_outp = relu_outp + .into_iter() + .map(|x| Rc::new(x)) + .collect::>(); + relu_outp + } else { + panic!("Unsupported activation type"); + }; + + let oc = match conv_config.conv_type { + ConvLayerEnum::Conv2D => weights.shape()[0], + ConvLayerEnum::DepthwiseConv2D => weights.shape()[3], + }; + + let out_shape = vec![batch_size, oh, ow, oc]; + let outp = Array::from_shape_vec(IxDyn(&out_shape), outp).unwrap(); + + Ok(vec![outp]) + } +} + +impl GadgetConsumer for Conv2DChip { + fn used_gadgets(&self, layer_params: Vec) -> Vec { + let conv_config = &Self::param_vec_to_config(layer_params.clone()); + let mut outp = vec![ + GadgetType::Adder, + GadgetType::DotProduct, + GadgetType::InputLookup, + GadgetType::BiasDivRoundRelu6, + ]; + + if conv_config.activation == ActivationType::Relu { + outp.push(GadgetType::Relu); + } + + outp + } +} diff --git a/mnist_zkml/src/layers/dag.rs b/mnist_zkml/src/layers/dag.rs new file mode 100644 index 0000000..b73a51a --- /dev/null +++ b/mnist_zkml/src/layers/dag.rs @@ -0,0 +1,484 @@ +use std::{collections::HashMap, fs::File, io::BufWriter, marker::PhantomData, rc::Rc}; + +use halo2_proofs::{circuit::Layouter, halo2curves::ff::PrimeField, plonk::Error}; + +use crate::{ + gadgets::gadget::{convert_to_u64, GadgetConfig}, + layers::{ + arithmetic::{add::AddChip, div_var::DivVarChip, mul::MulChip, 
sub::SubChip}, + batch_mat_mul::BatchMatMulChip, + div_fixed::DivFixedChip, + fully_connected::{FullyConnectedChip, FullyConnectedConfig}, + logistic::LogisticChip, + max_pool_2d::MaxPool2DChip, + mean::MeanChip, + noop::NoopChip, + pow::PowChip, + rsqrt::RsqrtChip, + shape::{ + broadcast::BroadcastChip, concatenation::ConcatenationChip, mask_neg_inf::MaskNegInfChip, + pack::PackChip, pad::PadChip, permute::PermuteChip, reshape::ReshapeChip, + resize_nn::ResizeNNChip, rotate::RotateChip, slice::SliceChip, split::SplitChip, + transpose::TransposeChip, + }, + softmax::SoftmaxChip, + sqrt::SqrtChip, + square::SquareChip, + squared_diff::SquaredDiffChip, + tanh::TanhChip, + update::UpdateChip, + }, + utils::helpers::print_assigned_arr, +}; + +use super::{ + avg_pool_2d::AvgPool2DChip, + conv2d::Conv2DChip, + layer::{AssignedTensor, CellRc, GadgetConsumer, Layer, LayerConfig, LayerType}, +}; + +#[derive(Clone, Debug, Default)] +pub struct DAGLayerConfig { + pub ops: Vec, + pub inp_idxes: Vec>, + pub out_idxes: Vec>, + pub final_out_idxes: Vec, +} + +pub struct DAGLayerChip { + dag_config: DAGLayerConfig, + _marker: PhantomData, +} + +impl DAGLayerChip { + pub fn construct(dag_config: DAGLayerConfig) -> Self { + Self { + dag_config, + _marker: PhantomData, + } + } + + // IMPORTANT: Assumes input tensors are in order. Output tensors can be in any order. 
+ pub fn forward( + &self, + mut layouter: impl Layouter, + tensors: &Vec>, + constants: &HashMap>, + gadget_config: Rc, + _layer_config: &LayerConfig, + ) -> Result<(HashMap>, Vec>), Error> { + // Tensor map + let mut tensor_map = HashMap::new(); + for (idx, tensor) in tensors.iter().enumerate() { + tensor_map.insert(idx, tensor.clone()); + } + + // Compute the dag + for (layer_idx, layer_config) in self.dag_config.ops.iter().enumerate() { + let layer_type = &layer_config.layer_type; + let inp_idxes = &self.dag_config.inp_idxes[layer_idx]; + let out_idxes = &self.dag_config.out_idxes[layer_idx]; + println!( + "Processing layer {}, type: {:?}, inp_idxes: {:?}, out_idxes: {:?}, layer_params: {:?}", + layer_idx, layer_type, inp_idxes, out_idxes, layer_config.layer_params + ); + let vec_inps = inp_idxes + .iter() + .map(|idx| tensor_map.get(idx).unwrap().clone()) + .collect::>(); + + let out = match layer_type { + LayerType::Add => { + let add_chip = AddChip {}; + add_chip.forward( + layouter.namespace(|| "dag add"), + &vec_inps, + constants, + gadget_config.clone(), + &layer_config, + )? + } + LayerType::AvgPool2D => { + let avg_pool_2d_chip = AvgPool2DChip {}; + avg_pool_2d_chip.forward( + layouter.namespace(|| "dag avg pool 2d"), + &vec_inps, + constants, + gadget_config.clone(), + &layer_config, + )? + } + LayerType::MaxPool2D => { + let max_pool_2d_chip = MaxPool2DChip { + marker: PhantomData::, + }; + max_pool_2d_chip.forward( + layouter.namespace(|| "dag max pool 2d"), + &vec_inps, + constants, + gadget_config.clone(), + &layer_config, + )? + } + LayerType::BatchMatMul => { + let batch_mat_mul_chip = BatchMatMulChip {}; + batch_mat_mul_chip.forward( + layouter.namespace(|| "dag batch mat mul"), + &vec_inps, + constants, + gadget_config.clone(), + &layer_config, + )? 
+ } + LayerType::Broadcast => { + let broadcast_chip = BroadcastChip {}; + broadcast_chip.forward( + layouter.namespace(|| "dag batch mat mul"), + &vec_inps, + constants, + gadget_config.clone(), + &layer_config, + )? + } + LayerType::Conv2D => { + let conv_2d_chip = Conv2DChip { + config: layer_config.clone(), + _marker: PhantomData, + }; + conv_2d_chip.forward( + layouter.namespace(|| "dag conv 2d"), + &vec_inps, + constants, + gadget_config.clone(), + &layer_config, + )? + } + LayerType::DivFixed => { + let div_fixed_chip = DivFixedChip {}; + div_fixed_chip.forward( + layouter.namespace(|| "dag div"), + &vec_inps, + constants, + gadget_config.clone(), + &layer_config, + )? + } + LayerType::DivVar => { + let div_var_chip = DivVarChip {}; + div_var_chip.forward( + layouter.namespace(|| "dag div"), + &vec_inps, + constants, + gadget_config.clone(), + &layer_config, + )? + } + LayerType::FullyConnected => { + let fc_chip = FullyConnectedChip { + _marker: PhantomData, + config: FullyConnectedConfig::construct(true), + }; + fc_chip.forward( + layouter.namespace(|| "dag fully connected"), + &vec_inps, + constants, + gadget_config.clone(), + &layer_config, + )? + } + LayerType::Softmax => { + let softmax_chip = SoftmaxChip {}; + softmax_chip.forward( + layouter.namespace(|| "dag softmax"), + &vec_inps, + constants, + gadget_config.clone(), + &layer_config, + )? + } + LayerType::Mean => { + let mean_chip = MeanChip {}; + mean_chip.forward( + layouter.namespace(|| "dag mean"), + &vec_inps, + constants, + gadget_config.clone(), + &layer_config, + )? + } + LayerType::Pad => { + let pad_chip = PadChip {}; + pad_chip.forward( + layouter.namespace(|| "dag pad"), + &vec_inps, + constants, + gadget_config.clone(), + &layer_config, + )? + } + LayerType::Permute => { + let pad_chip = PermuteChip {}; + pad_chip.forward( + layouter.namespace(|| "dag permute"), + &vec_inps, + constants, + gadget_config.clone(), + &layer_config, + )? 
+ } + LayerType::SquaredDifference => { + let squared_diff_chip = SquaredDiffChip {}; + squared_diff_chip.forward( + layouter.namespace(|| "dag squared diff"), + &vec_inps, + constants, + gadget_config.clone(), + &layer_config, + )? + } + LayerType::Rsqrt => { + let rsqrt_chip = RsqrtChip {}; + rsqrt_chip.forward( + layouter.namespace(|| "dag rsqrt"), + &vec_inps, + constants, + gadget_config.clone(), + &layer_config, + )? + } + LayerType::Sqrt => { + let sqrt_chip = SqrtChip {}; + sqrt_chip.forward( + layouter.namespace(|| "dag sqrt"), + &vec_inps, + constants, + gadget_config.clone(), + &layer_config, + )? + } + LayerType::Logistic => { + let logistic_chip = LogisticChip {}; + logistic_chip.forward( + layouter.namespace(|| "dag logistic"), + &vec_inps, + constants, + gadget_config.clone(), + &layer_config, + )? + } + LayerType::Pow => { + let pow_chip = PowChip {}; + pow_chip.forward( + layouter.namespace(|| "dag logistic"), + &vec_inps, + constants, + gadget_config.clone(), + &layer_config, + )? + } + LayerType::Tanh => { + let tanh_chip = TanhChip {}; + tanh_chip.forward( + layouter.namespace(|| "dag tanh"), + &vec_inps, + constants, + gadget_config.clone(), + &layer_config, + )? + } + LayerType::Mul => { + let mul_chip = MulChip {}; + mul_chip.forward( + layouter.namespace(|| "dag mul"), + &vec_inps, + constants, + gadget_config.clone(), + &layer_config, + )? + } + LayerType::Sub => { + let sub_chip = SubChip {}; + sub_chip.forward( + layouter.namespace(|| "dag sub"), + &vec_inps, + constants, + gadget_config.clone(), + &layer_config, + )? + } + LayerType::Noop => { + let noop_chip = NoopChip {}; + noop_chip.forward( + layouter.namespace(|| "dag noop"), + &vec_inps, + constants, + gadget_config.clone(), + &layer_config, + )? + } + LayerType::Transpose => { + let transpose_chip = TransposeChip {}; + transpose_chip.forward( + layouter.namespace(|| "dag transpose"), + &vec_inps, + constants, + gadget_config.clone(), + &layer_config, + )? 
+ } + LayerType::Reshape => { + let reshape_chip = ReshapeChip {}; + reshape_chip.forward( + layouter.namespace(|| "dag reshape"), + &vec_inps, + constants, + gadget_config.clone(), + &layer_config, + )? + } + LayerType::ResizeNN => { + let resize_nn_chip = ResizeNNChip {}; + resize_nn_chip.forward( + layouter.namespace(|| "dag resize nn"), + &vec_inps, + constants, + gadget_config.clone(), + &layer_config, + )? + } + LayerType::Rotate => { + let rotate_chip = RotateChip {}; + rotate_chip.forward( + layouter.namespace(|| "dag rotate"), + &vec_inps, + constants, + gadget_config.clone(), + &layer_config, + )? + } + LayerType::Concatenation => { + let concat_chip = ConcatenationChip {}; + concat_chip.forward( + layouter.namespace(|| "dag concatenation"), + &vec_inps, + constants, + gadget_config.clone(), + &layer_config, + )? + } + LayerType::Pack => { + let pack_chip = PackChip {}; + pack_chip.forward( + layouter.namespace(|| "dag pack"), + &vec_inps, + constants, + gadget_config.clone(), + &layer_config, + )? + } + LayerType::Split => { + let split_chip = SplitChip {}; + split_chip.forward( + layouter.namespace(|| "dag split"), + &vec_inps, + constants, + gadget_config.clone(), + &layer_config, + )? + } + LayerType::Update => { + let split_chip = UpdateChip {}; + split_chip.forward( + layouter.namespace(|| "dag update"), + &vec_inps, + constants, + gadget_config.clone(), + &layer_config, + )? + } + LayerType::Slice => { + let slice_chip = SliceChip {}; + slice_chip.forward( + layouter.namespace(|| "dag slice"), + &vec_inps, + constants, + gadget_config.clone(), + &layer_config, + )? + } + LayerType::MaskNegInf => { + let mask_neg_inf_chip = MaskNegInfChip {}; + mask_neg_inf_chip.forward( + layouter.namespace(|| "dag mask neg inf"), + &vec_inps, + constants, + gadget_config.clone(), + &layer_config, + )? 
+ } + LayerType::Square => { + let square_chip = SquareChip {}; + square_chip.forward( + layouter.namespace(|| "dag square"), + &vec_inps, + constants, + gadget_config.clone(), + &layer_config, + )? + } + }; + + for (idx, tensor_idx) in out_idxes.iter().enumerate() { + println!("Out {} shape: {:?}", idx, out[idx].shape()); + tensor_map.insert(*tensor_idx, out[idx].clone()); + } + println!(); + } + + let mut final_out = vec![]; + for idx in self.dag_config.final_out_idxes.iter() { + final_out.push(tensor_map.get(idx).unwrap().clone()); + } + + let print_arr = if final_out.len() > 0 { + &final_out[0] + } else { + if self.dag_config.ops.len() > 0 { + let last_layer_idx = self.dag_config.ops.len() - 1; + let out_idx = self.dag_config.out_idxes[last_layer_idx][0]; + tensor_map.get(&out_idx).unwrap() + } else { + tensor_map.get(&0).unwrap() + } + }; + + let tmp = print_arr.iter().map(|x| x.as_ref()).collect::>(); + print_assigned_arr("final out", &tmp.to_vec(), gadget_config.scale_factor); + println!("final out idxes: {:?}", self.dag_config.final_out_idxes); + + let mut x = vec![]; + for cell in print_arr.iter() { + cell.value().map(|v| { + let bias = 1 << 60 as i64; + let v_pos = *v + F::from(bias as u64); + let v = convert_to_u64(&v_pos) as i64 - bias; + x.push(v); + }); + } + if x.len() > 0 { + let out_fname = "out.msgpack"; + let f = File::create(out_fname).unwrap(); + let mut buf = BufWriter::new(f); + rmp_serde::encode::write_named(&mut buf, &x).unwrap(); + } + + Ok((tensor_map, final_out)) + } +} + +impl GadgetConsumer for DAGLayerChip { + // Special case: DAG doesn't do anything + fn used_gadgets(&self, _layer_params: Vec) -> Vec { + vec![] + } +} diff --git a/mnist_zkml/src/layers/div_fixed.rs b/mnist_zkml/src/layers/div_fixed.rs new file mode 100644 index 0000000..2c86277 --- /dev/null +++ b/mnist_zkml/src/layers/div_fixed.rs @@ -0,0 +1,93 @@ +use std::{collections::HashMap, rc::Rc, vec}; + +use halo2_proofs::{ + circuit::{AssignedCell, Layouter, Value}, + 
halo2curves::ff::PrimeField, + plonk::Error, +}; +use ndarray::{Array, IxDyn}; + +use crate::gadgets::{ + gadget::{Gadget, GadgetConfig, GadgetType}, + var_div::VarDivRoundChip, +}; + +use super::layer::{AssignedTensor, CellRc, GadgetConsumer, Layer, LayerConfig}; + +#[derive(Clone, Debug)] +pub struct DivFixedChip {} + +impl DivFixedChip { + fn get_div_val( + &self, + mut layouter: impl Layouter, + _tensors: &Vec>, + gadget_config: Rc, + layer_config: &LayerConfig, + ) -> Result, Error> { + // FIXME: this needs to be revealed + let div = layer_config.layer_params[0]; + let div = F::from(div as u64); + + let div = layouter + .assign_region( + || "division", + |mut region| { + let div = region + .assign_advice( + || "avg pool 2d div", + gadget_config.columns[0], + 0, + || Value::known(div), + ) + .unwrap(); + Ok(div) + }, + ) + .unwrap(); + + Ok(div) + } +} + +impl Layer for DivFixedChip { + fn forward( + &self, + mut layouter: impl Layouter, + tensors: &Vec>, + constants: &HashMap>, + gadget_config: Rc, + layer_config: &LayerConfig, + ) -> Result>, Error> { + let inp = &tensors[0]; + let inp_flat = inp.iter().map(|x| x.as_ref()).collect::>(); + + let zero = constants.get(&0).unwrap().as_ref(); + let shape = inp.shape(); + + let div = self.get_div_val( + layouter.namespace(|| "average div"), + tensors, + gadget_config.clone(), + layer_config, + )?; + + let var_div_chip = VarDivRoundChip::::construct(gadget_config.clone()); + + let dived = var_div_chip.forward( + layouter.namespace(|| "average div"), + &vec![inp_flat], + &vec![zero, &div], + )?; + let dived = dived.into_iter().map(|x| Rc::new(x)).collect::>(); + let out = Array::from_shape_vec(IxDyn(shape), dived).unwrap(); + + Ok(vec![out]) + } +} + +impl GadgetConsumer for DivFixedChip { + fn used_gadgets(&self, _layer_params: Vec) -> Vec { + vec![GadgetType::VarDivRound] + } +} diff --git a/mnist_zkml/src/layers/fully_connected.rs b/mnist_zkml/src/layers/fully_connected.rs new file mode 100644 index 
0000000..6dbf529 --- /dev/null +++ b/mnist_zkml/src/layers/fully_connected.rs @@ -0,0 +1,326 @@ +use std::{collections::HashMap, marker::PhantomData, rc::Rc}; + +use halo2_proofs::{ + circuit::{AssignedCell, Layouter, Region, Value}, + halo2curves::ff::PrimeField, + plonk::{Advice, Column, Error}, +}; +use ndarray::{Array, ArrayView, Axis, IxDyn}; + +use crate::{ + gadgets::{ + add_pairs::AddPairsChip, + dot_prod::DotProductChip, + gadget::{Gadget, GadgetConfig, GadgetType}, + nonlinear::relu::ReluChip, + var_div::VarDivRoundChip, + }, + layers::layer::ActivationType, + utils::helpers::RAND_START_IDX, +}; + +use super::layer::{AssignedTensor, CellRc, GadgetConsumer, Layer, LayerConfig}; + +pub struct FullyConnectedConfig { + pub normalize: bool, // Should be true +} + +impl FullyConnectedConfig { + pub fn construct(normalize: bool) -> Self { + Self { normalize } + } +} + +pub struct FullyConnectedChip { + pub _marker: PhantomData, + pub config: FullyConnectedConfig, +} + +impl FullyConnectedChip { + pub fn compute_mm( + // input: &AssignedTensor, + input: &ArrayView, IxDyn>, + weight: &AssignedTensor, + ) -> Array, IxDyn> { + assert_eq!(input.ndim(), 2); + assert_eq!(weight.ndim(), 2); + assert_eq!(input.shape()[1], weight.shape()[0]); + + let mut outp = vec![]; + for i in 0..input.shape()[0] { + for j in 0..weight.shape()[1] { + let mut sum = input[[i, 0]].value().map(|x: &F| *x) * weight[[0, j]].value(); + for k in 1..input.shape()[1] { + sum = sum + input[[i, k]].value().map(|x: &F| *x) * weight[[k, j]].value(); + } + outp.push(sum); + } + } + + let out_shape = [input.shape()[0], weight.shape()[1]]; + Array::from_shape_vec(IxDyn(out_shape.as_slice()), outp).unwrap() + } + + pub fn assign_array( + columns: &Vec>, + region: &mut Region, + array: &Array, IxDyn>, + ) -> Result, IxDyn>, Error> { + assert_eq!(array.ndim(), 2); + + let mut outp = vec![]; + for (idx, val) in array.iter().enumerate() { + let row_idx = idx / columns.len(); + let col_idx = idx % 
columns.len(); + let cell = region + .assign_advice(|| "assign array", columns[col_idx], row_idx, || *val) + .unwrap(); + outp.push(cell); + } + + let out_shape = [array.shape()[0], array.shape()[1]]; + Ok(Array::from_shape_vec(IxDyn(out_shape.as_slice()), outp).unwrap()) + } + + pub fn random_vector( + constants: &HashMap>, + size: usize, + ) -> Result>, Error> { + let mut outp = vec![]; + for idx in 0..size { + let idx = RAND_START_IDX + (idx as i64); + if !constants.contains_key(&idx) { + println!("Random vector is too small: {:?}", size); + } + let cell = constants.get(&idx).unwrap().clone(); + outp.push(cell); + } + + Ok(outp) + } + + fn get_activation(&self, layer_params: &Vec) -> ActivationType { + let activation = layer_params[0]; + match activation { + 0 => ActivationType::None, + 1 => ActivationType::Relu, + _ => panic!("Unsupported activation type for fully connected"), + } + } +} + +impl Layer for FullyConnectedChip { + fn forward( + &self, + mut layouter: impl Layouter, + tensors: &Vec>, + constants: &HashMap>, + gadget_config: Rc, + layer_config: &LayerConfig, + ) -> Result>, Error> { + assert!(tensors.len() <= 3); + let activation = self.get_activation(&layer_config.layer_params); + + let input = &tensors[0]; + let ndim = input.ndim(); + let input = if ndim == 2 { + ArrayView::from(input) + } else { + input.index_axis(Axis(0), 0) + }; + let weight = &tensors[1].t().into_owned(); + let zero = constants.get(&0).unwrap().as_ref(); + + // Compute and assign the result + let mm_result = layouter + .assign_region( + || "compute and assign mm", + |mut region| { + let mm_result = Self::compute_mm(&input, weight); + let mm_result = + Self::assign_array(&gadget_config.columns, &mut region, &mm_result).unwrap(); + + Ok(mm_result) + }, + ) + .unwrap(); + + // Generate random vectors + let r1 = Self::random_vector(constants, mm_result.shape()[0]).unwrap(); + let r2 = Self::random_vector(constants, mm_result.shape()[1]).unwrap(); + + let dot_prod_chip = 
DotProductChip::::construct(gadget_config.clone()); + let r1_ref = r1.iter().map(|x| x.as_ref()).collect::>(); + let r2_ref = r2.iter().map(|x| x.as_ref()).collect::>(); + + // Compute r1 * result + let mut r1_res = vec![]; + // println!("r1_ref: {:?}", r1_ref.len()); + // println!("r2_ref: {:?}", r2_ref.len()); + // println!("mm_result: {:?}", mm_result.shape()); + for i in 0..mm_result.shape()[1] { + let tmp = mm_result.index_axis(Axis(1), i); + let mm_ci = tmp.iter().collect::>(); + let r1_res_i = dot_prod_chip + .forward( + layouter.namespace(|| format!("r1_res_{}", i)), + &vec![mm_ci, r1_ref.clone()], + &vec![zero], + ) + .unwrap(); + r1_res.push(r1_res_i[0].clone()); + } + + // Compute r1 * result * r2 + let r1_res_ref = r1_res.iter().collect::>(); + let r1_res_r2 = dot_prod_chip + .forward( + layouter.namespace(|| "r1_res_r2"), + &vec![r1_res_ref, r2_ref.clone()], + &vec![zero], + ) + .unwrap(); + let r1_res_r2 = r1_res_r2[0].clone(); + // println!("r1_res_r2: {:?}", r1_res_r2); + + // Compute r1 * input + let mut r1_input = vec![]; + // println!("input: {:?}", input.shape()); + // println!("r1_ref: {:?}", r1_ref.len()); + for i in 0..input.shape()[1] { + let tmp = input.index_axis(Axis(1), i); + let input_ci = tmp.iter().map(|x| x.as_ref()).collect::>(); + let r1_input_i = dot_prod_chip + .forward( + layouter.namespace(|| format!("r1_input_{}", i)), + &vec![input_ci, r1_ref.clone()], + &vec![zero], + ) + .unwrap(); + r1_input.push(r1_input_i[0].clone()); + } + + // Compute weight * r2 + let mut weight_r2 = vec![]; + for i in 0..weight.shape()[0] { + let tmp = weight.index_axis(Axis(0), i); + let weight_ci = tmp.iter().map(|x| x.as_ref()).collect::>(); + let weight_r2_i = dot_prod_chip + .forward( + layouter.namespace(|| format!("weight_r2_{}", i)), + &vec![weight_ci, r2_ref.clone()], + &vec![zero], + ) + .unwrap(); + weight_r2.push(weight_r2_i[0].clone()); + } + + // Compute (r1 * input) * (weight * r2) + let r1_input_ref = r1_input.iter().collect::>(); + 
let weight_r2_ref = weight_r2.iter().collect::>(); + let r1_inp_weight_r2 = dot_prod_chip + .forward( + layouter.namespace(|| "r1_inp_weight_r2"), + &vec![r1_input_ref, weight_r2_ref], + &vec![zero], + ) + .unwrap(); + + let r1_inp_weight_r2 = r1_inp_weight_r2[0].clone(); + // println!("r1_inp_weight_r2: {:?}", r1_inp_weight_r2); + + layouter + .assign_region( + || "fc equality check", + |mut region| { + let t1 = r1_res_r2 + .copy_advice(|| "", &mut region, gadget_config.columns[0], 0) + .unwrap(); + let t2 = r1_inp_weight_r2 + .copy_advice(|| "", &mut region, gadget_config.columns[0], 1) + .unwrap(); + + region.constrain_equal(t1.cell(), t2.cell()).unwrap(); + + Ok(()) + }, + ) + .unwrap(); + + let shape = [mm_result.shape()[0], mm_result.shape()[1]]; + let final_result_flat = if self.config.normalize { + let mm_flat = mm_result.iter().collect::>(); + let var_div_chip = VarDivRoundChip::::construct(gadget_config.clone()); + let sf = constants + .get(&(gadget_config.scale_factor as i64)) + .unwrap() + .as_ref(); + let mm_div = var_div_chip + .forward( + layouter.namespace(|| "mm_div"), + &vec![mm_flat], + &vec![zero, sf], + ) + .unwrap(); + + let mm_div = if tensors.len() == 3 { + let bias = tensors[2].broadcast(shape.clone()).unwrap(); + let bias = bias.iter().map(|x| x.as_ref()).collect::>(); + let mm_div = mm_div.iter().collect::>(); + let adder_chip = AddPairsChip::::construct(gadget_config.clone()); + let mm_bias = adder_chip + .forward( + layouter.namespace(|| "mm_bias"), + &vec![mm_div, bias], + &vec![zero], + ) + .unwrap(); + mm_bias + } else { + mm_div + }; + + let mm_div = if activation == ActivationType::Relu { + let relu_chip = ReluChip::::construct(gadget_config.clone()); + let mm_div = mm_div.iter().collect::>(); + let vec_inputs = vec![mm_div]; + relu_chip + .forward(layouter.namespace(|| "relu"), &vec_inputs, &vec![zero]) + .unwrap() + } else if activation == ActivationType::None { + mm_div + } else { + panic!("Unsupported activation type"); + }; + 
+ mm_div.into_iter().map(|x| Rc::new(x)).collect::>() + } else { + mm_result + .into_iter() + .map(|x| Rc::new(x)) + .collect::>() + }; + let final_result = Array::from_shape_vec(IxDyn(&shape), final_result_flat).unwrap(); + + Ok(vec![final_result]) + } +} + +impl GadgetConsumer for FullyConnectedChip { + fn used_gadgets(&self, layer_params: Vec) -> Vec { + let activation = self.get_activation(&layer_params); + let mut outp = vec![ + GadgetType::Adder, + GadgetType::AddPairs, + GadgetType::DotProduct, + GadgetType::VarDivRound, + GadgetType::InputLookup, + ]; + match activation { + ActivationType::Relu => outp.push(GadgetType::Relu), + ActivationType::None => (), + _ => panic!("Unsupported activation type"), + } + outp + } +} diff --git a/mnist_zkml/src/layers/layer.rs b/mnist_zkml/src/layers/layer.rs new file mode 100644 index 0000000..676d6eb --- /dev/null +++ b/mnist_zkml/src/layers/layer.rs @@ -0,0 +1,89 @@ +use std::{collections::HashMap, rc::Rc}; + +use halo2_proofs::{ + circuit::{AssignedCell, Layouter}, + halo2curves::ff::PrimeField, + plonk::Error, +}; +use ndarray::{Array, IxDyn}; + +use crate::gadgets::gadget::{GadgetConfig, GadgetType}; + +#[derive(Clone, Copy, Debug, Default, Hash, Eq, PartialEq)] +pub enum LayerType { + Add, + AvgPool2D, + BatchMatMul, + Broadcast, + Concatenation, + Conv2D, + DivVar, + DivFixed, + FullyConnected, + Logistic, + MaskNegInf, + MaxPool2D, + Mean, + Mul, + #[default] + Noop, + Pack, + Pad, + Pow, + Permute, + Reshape, + ResizeNN, + Rotate, + Rsqrt, + Slice, + Softmax, + Split, + Sqrt, + Square, + SquaredDifference, + Sub, + Tanh, + Transpose, + Update, +} + +// NOTE: This is the same order as the TFLite schema +// Must not be changed +#[derive(Clone, Debug, Default, Hash, Eq, PartialEq)] +pub enum ActivationType { + #[default] + None, + Relu, + ReluN1To1, + Relu6, + Tanh, + SignBit, +} + +#[derive(Clone, Debug, Default)] +pub struct LayerConfig { + pub layer_type: LayerType, + pub layer_params: Vec, // This is turned into 
layer specific configurations at runtime + pub inp_shapes: Vec>, + pub out_shapes: Vec>, + pub mask: Vec, +} + +pub type CellRc = Rc>; +pub type AssignedTensor = Array, IxDyn>; +// General issue with rust: I'm not sure how to pass named arguments to a trait... +// Currently, the caller must be aware of the order of the tensors and results +pub trait Layer { + fn forward( + &self, + layouter: impl Layouter, + tensors: &Vec>, + constants: &HashMap>, + gadget_config: Rc, + layer_config: &LayerConfig, + ) -> Result>, Error>; +} + +pub trait GadgetConsumer { + fn used_gadgets(&self, layer_params: Vec) -> Vec; +} diff --git a/mnist_zkml/src/layers/logistic.rs b/mnist_zkml/src/layers/logistic.rs new file mode 100644 index 0000000..6ff9f45 --- /dev/null +++ b/mnist_zkml/src/layers/logistic.rs @@ -0,0 +1,49 @@ +use std::{collections::HashMap, rc::Rc, vec}; + +use halo2_proofs::{circuit::Layouter, halo2curves::ff::PrimeField, plonk::Error}; +use ndarray::{Array, IxDyn}; + +use crate::gadgets::{ + gadget::{Gadget, GadgetConfig, GadgetType}, + nonlinear::logistic::LogisticGadgetChip, +}; + +use super::layer::{AssignedTensor, CellRc, GadgetConsumer, Layer, LayerConfig}; + +#[derive(Clone, Debug)] +pub struct LogisticChip {} + +impl Layer for LogisticChip { + fn forward( + &self, + mut layouter: impl Layouter, + tensors: &Vec>, + constants: &HashMap>, + gadget_config: Rc, + _layer_config: &LayerConfig, + ) -> Result>, Error> { + let inp = &tensors[0]; + let inp_vec = inp.iter().map(|x| x.as_ref()).collect::>(); + let zero = constants.get(&0).unwrap().as_ref(); + + let logistic_chip = LogisticGadgetChip::::construct(gadget_config.clone()); + let vec_inps = vec![inp_vec]; + let constants = vec![zero]; + let out = logistic_chip.forward( + layouter.namespace(|| "logistic chip"), + &vec_inps, + &constants, + )?; + + let out = out.into_iter().map(|x| Rc::new(x)).collect::>(); + let out = Array::from_shape_vec(IxDyn(inp.shape()), out).unwrap(); + + Ok(vec![out]) + } +} + +impl 
GadgetConsumer for LogisticChip { + fn used_gadgets(&self, _layer_params: Vec) -> Vec { + vec![GadgetType::Logistic, GadgetType::InputLookup] + } +} diff --git a/mnist_zkml/src/layers/max_pool_2d.rs b/mnist_zkml/src/layers/max_pool_2d.rs new file mode 100644 index 0000000..c929092 --- /dev/null +++ b/mnist_zkml/src/layers/max_pool_2d.rs @@ -0,0 +1,124 @@ +use std::{collections::HashMap, rc::Rc}; + +use halo2_proofs::{circuit::Layouter, halo2curves::ff::PrimeField, plonk::Error}; +use ndarray::{Array, IxDyn}; + +use crate::{ + gadgets::{ + gadget::{Gadget, GadgetConfig, GadgetType}, + max::MaxChip, + }, + layers::conv2d::{Conv2DChip, PaddingEnum}, +}; + +use super::layer::{AssignedTensor, CellRc, GadgetConsumer, Layer, LayerConfig}; + +pub struct MaxPool2DChip { + pub marker: std::marker::PhantomData, +} + +impl MaxPool2DChip { + pub fn shape(inp: &AssignedTensor, layer_config: &LayerConfig) -> (usize, usize) { + let params = &layer_config.layer_params; + let (fx, fy) = (params[0], params[1]); + let (fx, fy) = (fx as usize, fy as usize); + let (sx, sy) = (params[2], params[3]); + let (sx, sy) = (sx as usize, sy as usize); + + // Only support batch size 1 for now + assert_eq!(inp.shape()[0], 1); + + let out_shape = Conv2DChip::::out_hw( + inp.shape()[1], + inp.shape()[2], + sx, + sy, + fx, + fy, + PaddingEnum::Valid, + ); + + out_shape + } + + pub fn splat( + inp: &AssignedTensor, + layer_config: &LayerConfig, + ) -> Result>>, Error> { + let params = &layer_config.layer_params; + let (fx, fy) = (params[0], params[1]); + let (fx, fy) = (fx as usize, fy as usize); + let (sx, sy) = (params[2], params[3]); + let (sx, sy) = (sx as usize, sy as usize); + + // Only support batch size 1 for now + assert_eq!(inp.shape()[0], 1); + + let out_shape = Self::shape(inp, layer_config); + + let mut splat = vec![]; + for i in 0..out_shape.0 { + for j in 0..out_shape.1 { + for k in 0..inp.shape()[3] { + let mut tmp = vec![]; + for x in 0..fx { + for y in 0..fy { + let x = i * sx + x; + 
let y = j * sy + y; + if x < inp.shape()[1] && y < inp.shape()[2] { + tmp.push(inp[[0, x, y, k]].clone()); + } + } + } + splat.push(tmp); + } + } + } + + Ok(splat) + } +} + +impl Layer for MaxPool2DChip { + fn forward( + &self, + mut layouter: impl Layouter, + tensors: &Vec>, + _constants: &HashMap>, + gadget_config: Rc, + layer_config: &LayerConfig, + ) -> Result>, Error> { + let inp = &tensors[0]; + let splat = Self::splat(inp, layer_config).unwrap(); + + let max_chip = MaxChip::::construct(gadget_config.clone()); + let mut out = vec![]; + for i in 0..splat.len() { + let inps = &splat[i]; + let inps = inps.iter().map(|x| x.as_ref()).collect(); + let max = max_chip + .forward( + layouter.namespace(|| format!("max {}", i)), + &vec![inps], + &vec![], + ) + .unwrap(); + out.push(max[0].clone()); + } + let out = out.into_iter().map(|x| Rc::new(x)).collect(); + + // TODO: refactor this + let out_xy = Self::shape(inp, layer_config); + let out_shape = vec![1, out_xy.0, out_xy.1, inp.shape()[3]]; + + let out = Array::from_shape_vec(IxDyn(&out_shape), out).unwrap(); + + Ok(vec![out]) + } +} + +impl GadgetConsumer for MaxPool2DChip { + fn used_gadgets(&self, _layer_params: Vec) -> Vec { + vec![GadgetType::Max, GadgetType::InputLookup] + } +} diff --git a/mnist_zkml/src/layers/mean.rs b/mnist_zkml/src/layers/mean.rs new file mode 100644 index 0000000..bfaf2f1 --- /dev/null +++ b/mnist_zkml/src/layers/mean.rs @@ -0,0 +1,139 @@ +use std::{collections::HashMap, rc::Rc}; + +use halo2_proofs::{ + circuit::{AssignedCell, Layouter, Value}, + halo2curves::ff::PrimeField, + plonk::Error, +}; +use ndarray::{Array, Axis, IxDyn}; + +use crate::gadgets::gadget::{GadgetConfig, GadgetType}; + +use super::{ + averager::Averager, + layer::{AssignedTensor, CellRc, GadgetConsumer, Layer, LayerConfig}, +}; + +pub struct MeanChip {} + +impl MeanChip { + pub fn get_keep_axis(&self, layer_config: &LayerConfig) -> usize { + let inp_shape = &layer_config.inp_shapes[0]; + let out_shape = 
&layer_config.out_shapes[0]; + assert_eq!(inp_shape[0], 1); + assert_eq!(out_shape[0], 1); + + // Skip the batch axis + let mut keep_axes = (1..inp_shape.len()).collect::>(); + for mean_axis in layer_config.layer_params.iter() { + keep_axes.retain(|&x| x != *mean_axis as usize); + } + assert_eq!(keep_axes.len(), 1); + keep_axes[0] + + /* + let mut num_same = 0; + let mut keep_axis: i64 = -1; + for i in 1..inp_shape.len() { + if inp_shape[i] == out_shape[i] { + keep_axis = i as i64; + num_same += 1; + } + } + + if keep_axis == -1 { + panic!("All axes are different"); + } + if num_same > 1 { + panic!("More than one axis is the same"); + } + keep_axis as usize + */ + } +} + +impl Averager for MeanChip { + fn splat(&self, input: &AssignedTensor, layer_config: &LayerConfig) -> Vec>> { + // Only support batch size = 1 + assert_eq!(input.shape()[0], 1); + // Only support batch + 2D, summing over one axis + // assert_eq!(input.shape().len(), 3); + let keep_axis = self.get_keep_axis(layer_config); + + let mut splat = vec![]; + for i in 0..input.shape()[keep_axis] { + let mut tmp = vec![]; + for x in input.index_axis(Axis(keep_axis), i).iter() { + tmp.push(x.clone()); + } + splat.push(tmp); + } + + splat + } + + fn get_div_val( + &self, + mut layouter: impl Layouter, + tensors: &Vec>, + gadget_config: Rc, + layer_config: &LayerConfig, + ) -> Result, Error> { + let inp = &tensors[0]; + let keep_axis = self.get_keep_axis(layer_config); + let mut div = 1; + for i in 0..inp.shape().len() { + if i != keep_axis { + div *= inp.shape()[i]; + } + } + + let div = F::from(div as u64); + // FIXME: put this in the fixed column + let div = layouter.assign_region( + || "mean div", + |mut region| { + let div = region.assign_advice( + || "mean div", + gadget_config.columns[0], + 0, + || Value::known(div), + )?; + Ok(div) + }, + )?; + + Ok(div) + } +} + +impl Layer for MeanChip { + fn forward( + &self, + layouter: impl Layouter, + tensors: &Vec>, + constants: &HashMap>, + gadget_config: Rc, + 
layer_config: &LayerConfig, + ) -> Result>, Error> { + let dived = self.avg_forward(layouter, tensors, constants, gadget_config, layer_config)?; + + let out_shape = layer_config.out_shapes[0] + .iter() + .map(|x| *x as usize) + .collect::>(); + + let out = Array::from_shape_vec(IxDyn(&out_shape), dived).unwrap(); + Ok(vec![out]) + } +} + +impl GadgetConsumer for MeanChip { + fn used_gadgets(&self, _layer_params: Vec) -> Vec { + vec![ + GadgetType::Adder, + GadgetType::VarDivRound, + GadgetType::InputLookup, + ] + } +} diff --git a/mnist_zkml/src/layers/noop.rs b/mnist_zkml/src/layers/noop.rs new file mode 100644 index 0000000..0e01f38 --- /dev/null +++ b/mnist_zkml/src/layers/noop.rs @@ -0,0 +1,29 @@ +use std::{collections::HashMap, rc::Rc}; + +use halo2_proofs::{circuit::Layouter, halo2curves::ff::PrimeField, plonk::Error}; + +use crate::gadgets::gadget::GadgetConfig; + +use super::layer::{AssignedTensor, CellRc, GadgetConsumer, Layer, LayerConfig}; + +pub struct NoopChip {} + +impl Layer for NoopChip { + fn forward( + &self, + _layouter: impl Layouter, + tensors: &Vec>, + _constants: &HashMap>, + _gadget_config: Rc, + layer_config: &LayerConfig, + ) -> Result>, Error> { + let ret_idx = layer_config.layer_params[0] as usize; + Ok(vec![tensors[ret_idx].clone()]) + } +} + +impl GadgetConsumer for NoopChip { + fn used_gadgets(&self, _layer_params: Vec) -> Vec { + vec![] + } +} diff --git a/mnist_zkml/src/layers/pow.rs b/mnist_zkml/src/layers/pow.rs new file mode 100644 index 0000000..7d4c443 --- /dev/null +++ b/mnist_zkml/src/layers/pow.rs @@ -0,0 +1,45 @@ +use std::{collections::HashMap, rc::Rc, vec}; + +use halo2_proofs::{circuit::Layouter, halo2curves::ff::PrimeField, plonk::Error}; +use ndarray::{Array, IxDyn}; + +use crate::gadgets::{ + gadget::{Gadget, GadgetConfig, GadgetType}, + nonlinear::pow::PowGadgetChip, +}; + +use super::layer::{AssignedTensor, CellRc, GadgetConsumer, Layer, LayerConfig}; + +#[derive(Clone, Debug)] +pub struct PowChip {} + +impl Layer 
for PowChip { + fn forward( + &self, + mut layouter: impl Layouter, + tensors: &Vec>, + constants: &HashMap>, + gadget_config: Rc, + _layer_config: &LayerConfig, + ) -> Result>, Error> { + let inp = &tensors[0]; + let inp_vec = inp.iter().map(|x| x.as_ref()).collect::>(); + let zero = constants.get(&0).unwrap().as_ref(); + + let pow_chip = PowGadgetChip::::construct(gadget_config.clone()); + let vec_inps = vec![inp_vec]; + let constants = vec![zero]; + let out = pow_chip.forward(layouter.namespace(|| "pow chip"), &vec_inps, &constants)?; + + let out = out.into_iter().map(|x| Rc::new(x)).collect::>(); + let out = Array::from_shape_vec(IxDyn(inp.shape()), out).unwrap(); + + Ok(vec![out]) + } +} + +impl GadgetConsumer for PowChip { + fn used_gadgets(&self, _layer_params: Vec) -> Vec { + vec![GadgetType::Pow, GadgetType::InputLookup] + } +} diff --git a/mnist_zkml/src/layers/rsqrt.rs b/mnist_zkml/src/layers/rsqrt.rs new file mode 100644 index 0000000..bbc6e1c --- /dev/null +++ b/mnist_zkml/src/layers/rsqrt.rs @@ -0,0 +1,71 @@ +use std::{collections::HashMap, rc::Rc, vec}; + +use halo2_proofs::{circuit::Layouter, halo2curves::ff::PrimeField, plonk::Error}; +use ndarray::{Array, IxDyn}; + +use crate::gadgets::{ + gadget::{Gadget, GadgetConfig, GadgetType}, + nonlinear::rsqrt::RsqrtGadgetChip, +}; + +use super::layer::{AssignedTensor, CellRc, GadgetConsumer, Layer, LayerConfig}; + +#[derive(Clone, Debug)] +pub struct RsqrtChip {} + +impl Layer for RsqrtChip { + fn forward( + &self, + mut layouter: impl Layouter, + tensors: &Vec>, + constants: &HashMap>, + gadget_config: Rc, + layer_config: &LayerConfig, + ) -> Result>, Error> { + let inp = &tensors[0]; + let mut inp_vec = vec![]; + + let mask = &layer_config.mask; + let mut mask_map = HashMap::new(); + for i in 0..mask.len() / 2 { + mask_map.insert(mask[2 * i], mask[2 * i + 1]); + } + + let min_val = gadget_config.min_val; + let min_val = constants.get(&min_val).unwrap().as_ref(); + let max_val = gadget_config.max_val; + 
let max_val = constants.get(&max_val).unwrap().as_ref(); + for (i, val) in inp.iter().enumerate() { + let i = i as i64; + if mask_map.contains_key(&i) { + let mask_val = *mask_map.get(&i).unwrap(); + if mask_val == 1 { + inp_vec.push(max_val); + } else if mask_val == -1 { + inp_vec.push(min_val); + } else { + panic!(); + } + } else { + inp_vec.push(val.as_ref()); + } + } + + let zero = constants.get(&0).unwrap().as_ref(); + let rsqrt_chip = RsqrtGadgetChip::::construct(gadget_config.clone()); + let vec_inps = vec![inp_vec]; + let constants = vec![zero, min_val, max_val]; + let out = rsqrt_chip.forward(layouter.namespace(|| "rsqrt chip"), &vec_inps, &constants)?; + + let out = out.into_iter().map(|x| Rc::new(x)).collect::>(); + let out = Array::from_shape_vec(IxDyn(inp.shape()), out).unwrap(); + + Ok(vec![out]) + } +} + +impl GadgetConsumer for RsqrtChip { + fn used_gadgets(&self, _layer_params: Vec) -> Vec { + vec![GadgetType::Rsqrt, GadgetType::InputLookup] + } +} diff --git a/mnist_zkml/src/layers/shape.rs b/mnist_zkml/src/layers/shape.rs new file mode 100644 index 0000000..12aa85c --- /dev/null +++ b/mnist_zkml/src/layers/shape.rs @@ -0,0 +1,12 @@ +pub mod broadcast; +pub mod concatenation; +pub mod mask_neg_inf; +pub mod pack; +pub mod pad; +pub mod permute; +pub mod reshape; +pub mod resize_nn; +pub mod rotate; +pub mod slice; +pub mod split; +pub mod transpose; diff --git a/mnist_zkml/src/layers/shape/broadcast.rs b/mnist_zkml/src/layers/shape/broadcast.rs new file mode 100644 index 0000000..c947bdf --- /dev/null +++ b/mnist_zkml/src/layers/shape/broadcast.rs @@ -0,0 +1,71 @@ +// +// Broadcast is used as a temporary measure to represent the backprop +// of a full-kernel AvgPool2D +// + +use std::{collections::HashMap, rc::Rc}; + +use halo2_proofs::{circuit::Layouter, halo2curves::ff::PrimeField, plonk::Error}; +use ndarray::Array; + +use crate::{ + gadgets::gadget::GadgetConfig, + layers::layer::{AssignedTensor, CellRc, GadgetConsumer}, +}; + +use 
super::super::layer::{Layer, LayerConfig}; + +pub struct BroadcastChip {} + +// TODO: Fix this after demo +impl Layer for BroadcastChip { + fn forward( + &self, + _layouter: impl Layouter, + tensors: &Vec>, + _constants: &HashMap>, + _gadget_config: Rc, + layer_config: &LayerConfig, + ) -> Result>, Error> { + let inp = &tensors[0]; + let shape = inp.shape(); + let output_shape = layer_config.out_shapes[0].clone(); + + // Check that we only broadcast dimensions with shape 1 + assert!(shape.len() == output_shape.len()); + assert!(shape.len() == 4); + + for (inp, outp) in shape.iter().zip(output_shape.iter()) { + if *inp != *outp && !(*inp == 1) { + panic!(); + } + } + + let mut output_flat = vec![]; + + for i in 0..output_shape[0] { + for j in 0..output_shape[1] { + for k in 0..output_shape[2] { + for l in 0..output_shape[3] { + let indexes = [i, j, k, l] + .iter() + .enumerate() + .map(|(idx, x)| if shape[idx] == 1 { 0 } else { *x }) + .collect::>(); + output_flat.push(inp[[indexes[0], indexes[1], indexes[2], indexes[3]]].clone()); + } + } + } + } + + println!("Broadcast : {:?} -> {:?}", inp.shape(), output_shape); + let out = Array::from_shape_vec(output_shape, output_flat).unwrap(); + Ok(vec![out]) + } +} + +impl GadgetConsumer for BroadcastChip { + fn used_gadgets(&self, _layer_params: Vec) -> Vec { + vec![] + } +} diff --git a/mnist_zkml/src/layers/shape/concatenation.rs b/mnist_zkml/src/layers/shape/concatenation.rs new file mode 100644 index 0000000..e81bdf9 --- /dev/null +++ b/mnist_zkml/src/layers/shape/concatenation.rs @@ -0,0 +1,37 @@ +use std::{collections::HashMap, rc::Rc}; + +use halo2_proofs::{circuit::Layouter, halo2curves::ff::PrimeField, plonk::Error}; +use ndarray::{concatenate, Axis}; + +use crate::{ + gadgets::gadget::{GadgetConfig, GadgetType}, + layers::layer::{AssignedTensor, CellRc, GadgetConsumer}, +}; + +use super::super::layer::{Layer, LayerConfig}; + +pub struct ConcatenationChip {} + +impl Layer for ConcatenationChip { + fn forward( + 
&self, + _layouter: impl Layouter, + tensors: &Vec>, + _constants: &HashMap>, + _gadget_config: Rc, + layer_config: &LayerConfig, + ) -> Result>, Error> { + let axis = layer_config.layer_params[0] as usize; + let views = tensors.iter().map(|x| x.view()).collect::>(); + // TODO: this is a bit of a hack + let out = concatenate(Axis(axis), views.as_slice()).unwrap_or(tensors[0].clone()); + + Ok(vec![out]) + } +} + +impl GadgetConsumer for ConcatenationChip { + fn used_gadgets(&self, _layer_params: Vec) -> Vec { + vec![] + } +} diff --git a/mnist_zkml/src/layers/shape/mask_neg_inf.rs b/mnist_zkml/src/layers/shape/mask_neg_inf.rs new file mode 100644 index 0000000..b4cbab2 --- /dev/null +++ b/mnist_zkml/src/layers/shape/mask_neg_inf.rs @@ -0,0 +1,55 @@ +use std::{collections::HashMap, rc::Rc}; + +use halo2_proofs::{circuit::Layouter, halo2curves::ff::PrimeField, plonk::Error}; +use ndarray::{Array, IxDyn}; + +use crate::{ + gadgets::gadget::GadgetConfig, + layers::layer::{AssignedTensor, CellRc, GadgetConsumer}, +}; + +use super::super::layer::{Layer, LayerConfig}; + +pub struct MaskNegInfChip {} + +impl Layer for MaskNegInfChip { + fn forward( + &self, + _layouter: impl Layouter, + tensors: &Vec>, + constants: &HashMap>, + gadget_config: Rc, + layer_config: &LayerConfig, + ) -> Result>, Error> { + let inp = &tensors[0]; + let mask_ndim = layer_config.layer_params[0] as usize; + let mask_shape = layer_config.layer_params[1..mask_ndim + 1] + .iter() + .map(|x| *x as usize) + .collect::>(); + + let mask_vec = layer_config.layer_params[mask_ndim + 1..].to_vec(); + let mask = Array::from_shape_vec(IxDyn(&mask_shape), mask_vec).unwrap(); + let mask = mask.broadcast(inp.raw_dim()).unwrap(); + + let min_val = gadget_config.min_val; + let min_val = constants.get(&min_val).unwrap().clone(); + let mut out_vec = vec![]; + for (val, to_mask) in inp.iter().zip(mask.iter()) { + if *to_mask == 0 { + out_vec.push(val.clone()); + } else { + out_vec.push(min_val.clone()); + } + } + + let 
outp = Array::from_shape_vec(inp.raw_dim(), out_vec).unwrap(); + Ok(vec![outp]) + } +} + +impl GadgetConsumer for MaskNegInfChip { + fn used_gadgets(&self, _layer_params: Vec) -> Vec { + vec![] + } +} diff --git a/mnist_zkml/src/layers/shape/pack.rs b/mnist_zkml/src/layers/shape/pack.rs new file mode 100644 index 0000000..e06ba75 --- /dev/null +++ b/mnist_zkml/src/layers/shape/pack.rs @@ -0,0 +1,46 @@ +use std::{collections::HashMap, rc::Rc}; + +use halo2_proofs::{circuit::Layouter, halo2curves::ff::PrimeField, plonk::Error}; +use ndarray::{concatenate, Axis}; + +use crate::{ + gadgets::gadget::{GadgetConfig, GadgetType}, + layers::layer::{AssignedTensor, CellRc, GadgetConsumer}, +}; + +use super::super::layer::{Layer, LayerConfig}; + +pub struct PackChip {} + +impl Layer for PackChip { + fn forward( + &self, + _layouter: impl Layouter, + tensors: &Vec>, + _constants: &HashMap>, + _gadget_config: Rc, + layer_config: &LayerConfig, + ) -> Result>, Error> { + let axis = layer_config.layer_params[0] as usize; + if axis > 1 { + panic!("Pack only supports axis=0 or axis=1"); + } + + let expanded = tensors + .into_iter() + .map(|x| x.clone().insert_axis(Axis(axis))) + .collect::>(); + let views = expanded.iter().map(|x| x.view()).collect::>(); + + // TODO: in some cases, the pack is unnecessary. 
Simply return the first tensor in this case + let out = concatenate(Axis(axis), views.as_slice()).unwrap_or(tensors[0].clone()); + + Ok(vec![out]) + } +} + +impl GadgetConsumer for PackChip { + fn used_gadgets(&self, _layer_params: Vec) -> Vec { + vec![] + } +} diff --git a/mnist_zkml/src/layers/shape/pad.rs b/mnist_zkml/src/layers/shape/pad.rs new file mode 100644 index 0000000..fae9c74 --- /dev/null +++ b/mnist_zkml/src/layers/shape/pad.rs @@ -0,0 +1,97 @@ +use std::{collections::HashMap, rc::Rc}; + +use halo2_proofs::{ + circuit::{AssignedCell, Layouter}, + halo2curves::ff::PrimeField, + plonk::Error, +}; +use ndarray::{Array, Axis, IxDyn, Slice}; + +use crate::{ + gadgets::gadget::GadgetConfig, + layers::layer::{AssignedTensor, GadgetConsumer}, +}; + +use super::super::layer::{Layer, LayerConfig}; + +// TODO: figure out where to put this +pub fn pad( + input: &Array, IxDyn>, + padding: Vec<[usize; 2]>, + pad_val: &Rc, +) -> Array, IxDyn> { + let tmp = input.iter().collect(); + let input = Array::from_shape_vec(input.raw_dim(), tmp).unwrap(); + assert_eq!(input.ndim(), padding.len()); + let mut padded_shape = input.raw_dim(); + for (ax, (&ax_len, &[pad_lo, pad_hi])) in input.shape().iter().zip(&padding).enumerate() { + padded_shape[ax] = ax_len + pad_lo + pad_hi; + } + + let mut padded = Array::from_elem(padded_shape, pad_val); + let padded_dim = padded.raw_dim(); + { + // Select portion of padded array that needs to be copied from the + // original array. + let mut orig_portion = padded.view_mut(); + for (ax, &[pad_lo, pad_hi]) in padding.iter().enumerate() { + orig_portion.slice_axis_inplace( + Axis(ax), + Slice::from(pad_lo as isize..padded_dim[ax] as isize - (pad_hi as isize)), + ); + } + // Copy the data from the original array. 
+ orig_portion.assign(&input.view()); + } + + let dim = padded.raw_dim(); + let tmp = padded.into_iter().map(|x| x.clone()).collect(); + let padded = Array::from_shape_vec(dim, tmp).unwrap(); + + padded +} + +pub struct PadChip {} + +pub struct PadConfig { + pub padding: Vec<[usize; 2]>, +} + +impl PadChip { + pub fn param_vec_to_config(layer_params: Vec) -> PadConfig { + assert!(layer_params.len() % 2 == 0); + + let padding = layer_params + .chunks(2) + .map(|chunk| [chunk[0] as usize, chunk[1] as usize]) + .collect(); + PadConfig { padding } + } +} + +impl Layer for PadChip { + fn forward( + &self, + _layouter: impl Layouter, + tensors: &Vec>, + constants: &HashMap>>, + _gadget_config: Rc, + layer_config: &LayerConfig, + ) -> Result>, Error> { + // FIXME: the pad from tflite is actually two, but mine is one + // assert_eq!(tensors.len(), 1); + let input = &tensors[0]; + + let zero = constants.get(&0).unwrap().clone(); + let padding = PadChip::param_vec_to_config(layer_config.layer_params.clone()); + let padded = pad(input, padding.padding, &zero); + + Ok(vec![padded]) + } +} + +impl GadgetConsumer for PadChip { + fn used_gadgets(&self, _layer_params: Vec) -> Vec { + vec![] + } +} diff --git a/mnist_zkml/src/layers/shape/permute.rs b/mnist_zkml/src/layers/shape/permute.rs new file mode 100644 index 0000000..90121b3 --- /dev/null +++ b/mnist_zkml/src/layers/shape/permute.rs @@ -0,0 +1,43 @@ +use std::{collections::HashMap, rc::Rc}; + +use halo2_proofs::{circuit::Layouter, halo2curves::ff::PrimeField, plonk::Error}; +use ndarray::IxDyn; + +use crate::{ + gadgets::gadget::GadgetConfig, + layers::layer::{AssignedTensor, CellRc, GadgetConsumer}, +}; + +use super::super::layer::{Layer, LayerConfig}; + +pub struct PermuteChip {} + +impl Layer for PermuteChip { + fn forward( + &self, + _layouter: impl Layouter, + tensors: &Vec>, + _constants: &HashMap>, + _gadget_config: Rc, + layer_config: &LayerConfig, + ) -> Result>, Error> { + let inp = &tensors[0]; + let params = 
&layer_config + .layer_params + .iter() + .map(|x| *x as usize) + .collect::>()[..]; + + assert!(inp.ndim() == params.len()); + + let out = inp.clone(); + let out = out.permuted_axes(IxDyn(params)); + Ok(vec![out]) + } +} + +impl GadgetConsumer for PermuteChip { + fn used_gadgets(&self, _layer_params: Vec) -> Vec { + vec![] + } +} diff --git a/mnist_zkml/src/layers/shape/reshape.rs b/mnist_zkml/src/layers/shape/reshape.rs new file mode 100644 index 0000000..abbbfe6 --- /dev/null +++ b/mnist_zkml/src/layers/shape/reshape.rs @@ -0,0 +1,38 @@ +use std::{collections::HashMap, rc::Rc}; + +use halo2_proofs::{circuit::Layouter, halo2curves::ff::PrimeField, plonk::Error}; +use ndarray::Array; + +use crate::{ + gadgets::gadget::GadgetConfig, + layers::layer::{AssignedTensor, CellRc, GadgetConsumer}, +}; + +use super::super::layer::{Layer, LayerConfig}; + +pub struct ReshapeChip {} + +impl Layer for ReshapeChip { + fn forward( + &self, + _layouter: impl Layouter, + tensors: &Vec>, + _constants: &HashMap>, + _gadget_config: Rc, + layer_config: &LayerConfig, + ) -> Result>, Error> { + let inp = &tensors[0]; + let shape = layer_config.out_shapes[0].clone(); + + println!("Reshape: {:?} -> {:?}", inp.shape(), shape); + let flat = inp.iter().map(|x| x.clone()).collect(); + let out = Array::from_shape_vec(shape, flat).unwrap(); + Ok(vec![out]) + } +} + +impl GadgetConsumer for ReshapeChip { + fn used_gadgets(&self, _layer_params: Vec) -> Vec { + vec![] + } +} diff --git a/mnist_zkml/src/layers/shape/resize_nn.rs b/mnist_zkml/src/layers/shape/resize_nn.rs new file mode 100644 index 0000000..2ada4a4 --- /dev/null +++ b/mnist_zkml/src/layers/shape/resize_nn.rs @@ -0,0 +1,56 @@ +use std::{collections::HashMap, rc::Rc}; + +use halo2_proofs::{circuit::Layouter, halo2curves::ff::PrimeField, plonk::Error}; +use ndarray::{Array, IxDyn}; + +use crate::{ + gadgets::gadget::GadgetConfig, + layers::layer::{AssignedTensor, CellRc, GadgetConsumer}, +}; + +use super::super::layer::{Layer, 
LayerConfig}; + +pub struct ResizeNNChip {} + +// TODO: this does not work in general +impl Layer for ResizeNNChip { + fn forward( + &self, + _layouter: impl Layouter, + tensors: &Vec>, + _constants: &HashMap>, + _gadget_config: Rc, + layer_config: &LayerConfig, + ) -> Result>, Error> { + let inp = &tensors[0]; + let output_shape = layer_config.out_shapes[0].clone(); + + assert_eq!(inp.ndim(), 4); + assert_eq!(inp.shape()[0], 1); + assert_eq!(inp.shape()[3], output_shape[3]); + + let mut flat = vec![]; + // Do nearest neighbor interpolation over batch, h, w, c + // The interpolation is over h and w + for b in 0..inp.shape()[0] { + for h in 0..output_shape[1] { + let h_in = (h as f64 * (inp.shape()[1] as f64 / output_shape[1] as f64)) as usize; + for w in 0..output_shape[2] { + let w_in = (w as f64 * (inp.shape()[2] as f64 / output_shape[2] as f64)) as usize; + for c in 0..inp.shape()[3] { + flat.push(inp[[b, h_in, w_in, c]].clone()); + } + } + } + } + + let outp = Array::from_shape_vec(IxDyn(&output_shape), flat).unwrap(); + Ok(vec![outp]) + } +} + +impl GadgetConsumer for ResizeNNChip { + fn used_gadgets(&self, _layer_params: Vec) -> Vec { + vec![] + } +} diff --git a/mnist_zkml/src/layers/shape/rotate.rs b/mnist_zkml/src/layers/shape/rotate.rs new file mode 100644 index 0000000..618c526 --- /dev/null +++ b/mnist_zkml/src/layers/shape/rotate.rs @@ -0,0 +1,74 @@ +// TODO: The implementation is not ideal. 
+ +use std::{collections::HashMap, rc::Rc}; + +use halo2_proofs::{circuit::Layouter, halo2curves::ff::PrimeField, plonk::Error}; + +use crate::{ + gadgets::gadget::GadgetConfig, + layers::layer::{AssignedTensor, CellRc, GadgetConsumer}, +}; + +use super::super::layer::{Layer, LayerConfig}; + +pub struct RotateChip {} + +// Example: +// input: +// [1 2 3 4] +// [5 6 7 8] +// +// params: [1] -- flip axis 1 only +// output: +// [4 3 2 1] +// [8 7 6 5] +impl Layer for RotateChip { + fn forward( + &self, + _layouter: impl Layouter, + tensors: &Vec>, + _constants: &HashMap>, + _gadget_config: Rc, + layer_config: &LayerConfig, + ) -> Result>, Error> { + let inp = &tensors[0]; + let params = &layer_config.layer_params; + + assert!(inp.shape().len() == 4); + + let mut flip = vec![false; 4]; + for p in params { + flip[*p as usize] = true; + } + let shape = inp.shape(); + + println!("Rotate: {:?} -> {:?}", inp.shape(), shape); + + let mut out = inp.clone(); + + for i in 0..shape[0] { + for j in 0..shape[1] { + for k in 0..shape[2] { + for l in 0..shape[3] { + let [ix, jx, kx, lx]: [usize; 4] = [i, j, k, l] + .iter() + .enumerate() + .map(|(idx, x)| if flip[idx] { shape[idx] - 1 - *x } else { *x }) + .collect::>() + .try_into() + .unwrap(); + out[[ix, jx, kx, lx]] = inp[[i, j, k, l]].clone(); + } + } + } + } + + Ok(vec![out]) + } +} + +impl GadgetConsumer for RotateChip { + fn used_gadgets(&self, _layer_params: Vec) -> Vec { + vec![] + } +} diff --git a/mnist_zkml/src/layers/shape/slice.rs b/mnist_zkml/src/layers/shape/slice.rs new file mode 100644 index 0000000..6cfd653 --- /dev/null +++ b/mnist_zkml/src/layers/shape/slice.rs @@ -0,0 +1,48 @@ +use std::{collections::HashMap, rc::Rc}; + +use halo2_proofs::{circuit::Layouter, halo2curves::ff::PrimeField, plonk::Error}; +use ndarray::Slice; + +use crate::{ + gadgets::gadget::{GadgetConfig, GadgetType}, + layers::layer::{AssignedTensor, CellRc, GadgetConsumer}, +}; + +use super::super::layer::{Layer, LayerConfig}; + +pub struct 
SliceChip {} + +impl Layer for SliceChip { + fn forward( + &self, + _layouter: impl Layouter, + tensors: &Vec>, + _constants: &HashMap>, + _gadget_config: Rc, + layer_config: &LayerConfig, + ) -> Result>, Error> { + let params = &layer_config.layer_params; + assert_eq!(params.len() % 2, 0); + let num_axes = params.len() / 2; + let starts = ¶ms[0..num_axes]; + let sizes = ¶ms[num_axes..]; + + let inp = &tensors[0]; + let outp = inp.slice_each_axis(|ax| { + let start = starts[ax.axis.0] as usize; + let size = sizes[ax.axis.0]; + if size == -1 { + Slice::from(start..) + } else { + Slice::from(start..(start + size as usize)) + } + }); + Ok(vec![outp.to_owned()]) + } +} + +impl GadgetConsumer for SliceChip { + fn used_gadgets(&self, _layer_params: Vec) -> Vec { + vec![] + } +} diff --git a/mnist_zkml/src/layers/shape/split.rs b/mnist_zkml/src/layers/shape/split.rs new file mode 100644 index 0000000..071dfbe --- /dev/null +++ b/mnist_zkml/src/layers/shape/split.rs @@ -0,0 +1,47 @@ +use std::{collections::HashMap, rc::Rc}; + +use halo2_proofs::{circuit::Layouter, halo2curves::ff::PrimeField, plonk::Error}; +use ndarray::{Axis, Slice}; + +use crate::{ + gadgets::gadget::{GadgetConfig, GadgetType}, + layers::layer::{AssignedTensor, CellRc, GadgetConsumer}, +}; + +use super::super::layer::{Layer, LayerConfig}; + +pub struct SplitChip {} + +impl Layer for SplitChip { + fn forward( + &self, + _layouter: impl Layouter, + tensors: &Vec>, + _constants: &HashMap>, + _gadget_config: Rc, + layer_config: &LayerConfig, + ) -> Result>, Error> { + let axis = layer_config.layer_params[0] as usize; + let num_splits = layer_config.layer_params[1] as usize; + let inp = &tensors[1]; + + let mut out = vec![]; + let split_len = inp.shape()[axis] / num_splits; + for i in 0..num_splits { + let slice = inp + .slice_axis( + Axis(axis), + Slice::from((i * split_len)..((i + 1) * split_len)), + ) + .to_owned(); + out.push(slice.to_owned()); + } + Ok(out) + } +} + +impl GadgetConsumer for SplitChip { 
+ fn used_gadgets(&self, _layer_params: Vec) -> Vec { + vec![] + } +} diff --git a/mnist_zkml/src/layers/shape/transpose.rs b/mnist_zkml/src/layers/shape/transpose.rs new file mode 100644 index 0000000..d9de1bf --- /dev/null +++ b/mnist_zkml/src/layers/shape/transpose.rs @@ -0,0 +1,52 @@ +use std::{collections::HashMap, rc::Rc}; + +use halo2_proofs::{circuit::Layouter, halo2curves::ff::PrimeField, plonk::Error}; +use ndarray::{Array, IxDyn}; + +use crate::{ + gadgets::gadget::GadgetConfig, + layers::layer::{AssignedTensor, CellRc, GadgetConsumer}, +}; + +use super::super::layer::{Layer, LayerConfig}; + +pub struct TransposeChip {} + +impl Layer for TransposeChip { + fn forward( + &self, + _layouter: impl Layouter, + tensors: &Vec>, + _constants: &HashMap>, + _gadget_config: Rc, + layer_config: &LayerConfig, + ) -> Result>, Error> { + assert_eq!(layer_config.layer_params.len() % 2, 0); + let ndim = layer_config.layer_params.len() / 2; + let inp_shape = layer_config.layer_params[0..ndim] + .to_vec() + .iter() + .map(|x| *x as usize) + .collect::>(); + let permutation = layer_config.layer_params[ndim..] 
+ .to_vec() + .iter() + .map(|x| *x as usize) + .collect::>(); + + let inp = &tensors[0]; + // Required because of memory layout issues + let inp_flat = inp.iter().cloned().collect::>(); + let inp = Array::from_shape_vec(IxDyn(&inp_shape), inp_flat).unwrap(); + + let inp = inp.permuted_axes(IxDyn(&permutation)); + + Ok(vec![inp]) + } +} + +impl GadgetConsumer for TransposeChip { + fn used_gadgets(&self, _layer_params: Vec) -> Vec { + vec![] + } +} diff --git a/mnist_zkml/src/layers/softmax.rs b/mnist_zkml/src/layers/softmax.rs new file mode 100644 index 0000000..7a6aae8 --- /dev/null +++ b/mnist_zkml/src/layers/softmax.rs @@ -0,0 +1,212 @@ +use std::{collections::HashMap, rc::Rc, vec}; + +use halo2_proofs::{ + circuit::{AssignedCell, Layouter}, + halo2curves::ff::PrimeField, + plonk::Error, +}; +use ndarray::{s, Array, IxDyn}; + +use crate::gadgets::{ + adder::AdderChip, + gadget::{Gadget, GadgetConfig, GadgetType}, + max::MaxChip, + nonlinear::exp::ExpGadgetChip, + sub_pairs::SubPairsChip, + var_div_big3::VarDivRoundBig3Chip, +}; + +use super::layer::{AssignedTensor, CellRc, GadgetConsumer, Layer, LayerConfig}; + +#[derive(Clone, Debug)] +pub struct SoftmaxChip {} + +impl SoftmaxChip { + pub fn softmax_flat( + mut layouter: impl Layouter, + constants: &HashMap>, + inp_flat: Vec<&AssignedCell>, + gadget_config: Rc, + mask: &Vec, + ) -> Result>, Error> { + let exp_chip = ExpGadgetChip::::construct(gadget_config.clone()); + let adder_chip = AdderChip::::construct(gadget_config.clone()); + let sub_pairs_chip = SubPairsChip::::construct(gadget_config.clone()); + let max_chip = MaxChip::::construct(gadget_config.clone()); + let var_div_big_chip = VarDivRoundBig3Chip::::construct(gadget_config.clone()); + + let zero = constants.get(&0).unwrap().as_ref(); + let sf = constants + .get(&(gadget_config.scale_factor as i64)) + .unwrap() + .as_ref(); + + // Mask the input for max computation and subtraction + let inp_take = inp_flat + .iter() + .enumerate() + .filter(|(i, _)| 
mask[*i] == 0) // Awkwardly, 1 = take negative infinity + .map(|(_, x)| *x) + .collect::>(); + + // Compute the max + let max = max_chip + .forward( + layouter.namespace(|| format!("max")), + &vec![inp_take.clone()], + &vec![zero], + ) + .unwrap(); + let max = &max[0]; + + // Subtract the max + let max_flat = vec![max; inp_take.len()]; + let sub = sub_pairs_chip.forward( + layouter.namespace(|| format!("sub")), + &vec![inp_take, max_flat], + &vec![zero], + )?; + + let sub = sub.iter().collect::>(); + + // Compute the exp + let exp_slice = exp_chip.forward( + layouter.namespace(|| format!("exp")), + &vec![sub], + &vec![zero], + )?; + + // Compute the sum + let sum = adder_chip.forward( + layouter.namespace(|| format!("sum")), + &vec![exp_slice.iter().collect()], + &vec![zero], + )?; + let sum = sum[0].clone(); + let sum_div_sf = var_div_big_chip.forward( + layouter.namespace(|| format!("sum div sf")), + &vec![vec![&sum]], + &vec![zero, sf], + )?; + let sum_div_sf = sum_div_sf[0].clone(); + + let dived = var_div_big_chip.forward( + layouter.namespace(|| format!("div")), + &vec![exp_slice.iter().collect()], + &vec![zero, &sum_div_sf], + )?; + + // Take either zero (softmax(-inf)) or the result + let mut div_idx = 0; + let dived = mask + .iter() + .map(|x| { + if *x == 1 { + zero.clone() + } else { + let tmp = dived[div_idx].clone(); + div_idx = div_idx + 1; + tmp + } + }) + .collect(); + + Ok(dived) + } +} + +impl Layer for SoftmaxChip { + fn forward( + &self, + mut layouter: impl Layouter, + tensors: &Vec>, + constants: &HashMap>, + gadget_config: Rc, + layer_config: &LayerConfig, + ) -> Result>, Error> { + let inp = &tensors[0]; + assert!(inp.ndim() == 2 || inp.ndim() == 3 || inp.ndim() == 4); + if inp.ndim() == 4 { + assert_eq!(inp.shape()[0], 1); + } + + let inp_shape = inp.shape().iter().map(|x| *x).collect::>(); + let mask = if layer_config.layer_params.len() == 0 { + Array::from_shape_fn(IxDyn(&inp_shape), |_| 0) + } else { + let mask_shape_len = 
layer_config.layer_params[0] as usize; + let mask_shape = layer_config.layer_params[1..(1 + mask_shape_len)] + .iter() + .map(|x| *x as usize) + .collect::>(); + let mask = layer_config.layer_params[(1 + mask_shape_len)..].to_vec(); + let mask = Array::from_shape_vec(IxDyn(&mask_shape), mask).unwrap(); + let mask = mask.broadcast(IxDyn(&inp_shape)).unwrap().to_owned(); + mask + }; + + let shape = if inp.ndim() == 2 || inp.ndim() == 3 { + inp.shape().iter().map(|x| *x).collect::>() + } else { + vec![inp.shape()[1], inp.shape()[2], inp.shape()[3]] + }; + let inp = inp.to_owned().into_shape(shape.clone()).unwrap(); + let mask = mask.into_shape(shape.clone()).unwrap(); + + let mut outp = vec![]; + if inp.ndim() == 2 { + for i in 0..shape[0] { + let inp_slice = inp.slice(s![i, ..]); + let inp_flat = inp_slice.iter().map(|x| x.as_ref()).collect::>(); + let mask_slice = mask.slice(s![i, ..]); + let mask_flat = mask_slice.iter().map(|x| *x as i64).collect::>(); + let dived = Self::softmax_flat( + layouter.namespace(|| format!("softmax {}", i)), + constants, + inp_flat, + gadget_config.clone(), + &mask_flat, + ) + .unwrap(); + outp.extend(dived); + } + } else if inp.ndim() == 3 { + for i in 0..shape[0] { + for j in 0..shape[1] { + let inp_slice = inp.slice(s![i, j, ..]); + let inp_flat = inp_slice.iter().map(|x| x.as_ref()).collect::>(); + let mask_slice = mask.slice(s![i, j, ..]); + let mask_flat = mask_slice.iter().map(|x| *x as i64).collect::>(); + let dived = Self::softmax_flat( + layouter.namespace(|| format!("softmax {} {}", i, j)), + constants, + inp_flat, + gadget_config.clone(), + &mask_flat, + ) + .unwrap(); + outp.extend(dived); + } + } + } else { + panic!("Not implemented"); + } + + let outp = outp.into_iter().map(|x| Rc::new(x)).collect::>(); + let outp = Array::from_shape_vec(IxDyn(inp.shape()), outp).unwrap(); + Ok(vec![outp]) + } +} + +impl GadgetConsumer for SoftmaxChip { + fn used_gadgets(&self, _layer_params: Vec) -> Vec { + vec![ + GadgetType::Exp, + 
GadgetType::Adder, + GadgetType::VarDivRoundBig3, + GadgetType::Max, + GadgetType::SubPairs, + GadgetType::InputLookup, + ] + } +} diff --git a/mnist_zkml/src/layers/sqrt.rs b/mnist_zkml/src/layers/sqrt.rs new file mode 100644 index 0000000..ddd0aa3 --- /dev/null +++ b/mnist_zkml/src/layers/sqrt.rs @@ -0,0 +1,71 @@ +use std::{collections::HashMap, rc::Rc, vec}; + +use halo2_proofs::{circuit::Layouter, halo2curves::ff::PrimeField, plonk::Error}; +use ndarray::{Array, IxDyn}; + +use crate::gadgets::{ + gadget::{Gadget, GadgetConfig, GadgetType}, + nonlinear::sqrt::SqrtGadgetChip, +}; + +use super::layer::{AssignedTensor, CellRc, GadgetConsumer, Layer, LayerConfig}; + +#[derive(Clone, Debug)] +pub struct SqrtChip {} + +impl Layer for SqrtChip { + fn forward( + &self, + mut layouter: impl Layouter, + tensors: &Vec>, + constants: &HashMap>, + gadget_config: Rc, + layer_config: &LayerConfig, + ) -> Result>, Error> { + let inp = &tensors[0]; + let mut inp_vec = vec![]; + + let mask = &layer_config.mask; + let mut mask_map = HashMap::new(); + for i in 0..mask.len() / 2 { + mask_map.insert(mask[2 * i], mask[2 * i + 1]); + } + + let min_val = gadget_config.min_val; + let min_val = constants.get(&min_val).unwrap().as_ref(); + let max_val = gadget_config.max_val; + let max_val = constants.get(&max_val).unwrap().as_ref(); + for (i, val) in inp.iter().enumerate() { + let i = i as i64; + if mask_map.contains_key(&i) { + let mask_val = *mask_map.get(&i).unwrap(); + if mask_val == 1 { + inp_vec.push(max_val); + } else if mask_val == -1 { + inp_vec.push(min_val); + } else { + panic!(); + } + } else { + inp_vec.push(val.as_ref()); + } + } + + let zero = constants.get(&0).unwrap().as_ref(); + let sqrt_chip = SqrtGadgetChip::::construct(gadget_config.clone()); + let vec_inps = vec![inp_vec]; + let constants = vec![zero, min_val, max_val]; + let out = sqrt_chip.forward(layouter.namespace(|| "sqrt chip"), &vec_inps, &constants)?; + + let out = out.into_iter().map(|x| 
Rc::new(x)).collect::>(); + let out = Array::from_shape_vec(IxDyn(inp.shape()), out).unwrap(); + + Ok(vec![out]) + } +} + +impl GadgetConsumer for SqrtChip { + fn used_gadgets(&self, _layer_params: Vec) -> Vec { + vec![GadgetType::Sqrt, GadgetType::InputLookup] + } +} diff --git a/mnist_zkml/src/layers/square.rs b/mnist_zkml/src/layers/square.rs new file mode 100644 index 0000000..0ca4ba7 --- /dev/null +++ b/mnist_zkml/src/layers/square.rs @@ -0,0 +1,69 @@ +use std::{collections::HashMap, rc::Rc, vec}; + +use halo2_proofs::{circuit::Layouter, halo2curves::ff::PrimeField, plonk::Error}; +use ndarray::{Array, IxDyn}; + +use crate::gadgets::{ + gadget::{Gadget, GadgetConfig, GadgetType}, + square::SquareGadgetChip, + var_div::VarDivRoundChip, +}; + +use super::layer::{AssignedTensor, CellRc, GadgetConsumer, Layer, LayerConfig}; + +#[derive(Clone, Debug)] +pub struct SquareChip {} + +impl Layer for SquareChip { + fn forward( + &self, + mut layouter: impl Layouter, + tensors: &Vec>, + constants: &HashMap>, + gadget_config: Rc, + _layer_config: &LayerConfig, + ) -> Result>, Error> { + assert_eq!(tensors.len(), 1); + + let inp = &tensors[0]; + let zero = constants.get(&0).unwrap().as_ref(); + + let square_chip = SquareGadgetChip::::construct(gadget_config.clone()); + let inp_vec = inp.iter().map(|x| x.as_ref()).collect::>(); + let vec_inputs = vec![inp_vec]; + let single_inps = vec![zero]; + let out = square_chip.forward( + layouter.namespace(|| "square chip"), + &vec_inputs, + &single_inps, + )?; + + let var_div_chip = VarDivRoundChip::::construct(gadget_config.clone()); + let div = constants + .get(&(gadget_config.scale_factor as i64)) + .unwrap() + .as_ref(); + let single_inps = vec![zero, div]; + let out = out.iter().collect::>(); + let vec_inputs = vec![out]; + let out = var_div_chip.forward( + layouter.namespace(|| "var div chip"), + &vec_inputs, + &single_inps, + )?; + + let out = out.into_iter().map(|x| Rc::new(x)).collect::>(); + let out = 
Array::from_shape_vec(IxDyn(inp.shape()), out).unwrap(); + Ok(vec![out]) + } +} + +impl GadgetConsumer for SquareChip { + fn used_gadgets(&self, _layer_params: Vec) -> Vec { + vec![ + GadgetType::Square, + GadgetType::VarDivRound, + GadgetType::InputLookup, + ] + } +} diff --git a/mnist_zkml/src/layers/squared_diff.rs b/mnist_zkml/src/layers/squared_diff.rs new file mode 100644 index 0000000..7900b49 --- /dev/null +++ b/mnist_zkml/src/layers/squared_diff.rs @@ -0,0 +1,77 @@ +use std::{collections::HashMap, rc::Rc, vec}; + +use halo2_proofs::{circuit::Layouter, halo2curves::ff::PrimeField, plonk::Error}; +use ndarray::{Array, IxDyn}; + +use crate::{ + gadgets::{ + gadget::{Gadget, GadgetConfig, GadgetType}, + squared_diff::SquaredDiffGadgetChip, + var_div::VarDivRoundChip, + }, + utils::helpers::broadcast, +}; + +use super::layer::{AssignedTensor, CellRc, GadgetConsumer, Layer, LayerConfig}; + +#[derive(Clone, Debug)] +pub struct SquaredDiffChip {} + +impl Layer for SquaredDiffChip { + fn forward( + &self, + mut layouter: impl Layouter, + tensors: &Vec>, + constants: &HashMap>, + gadget_config: Rc, + _layer_config: &LayerConfig, + ) -> Result>, Error> { + assert_eq!(tensors.len(), 2); + let inp1 = &tensors[0]; + let inp2 = &tensors[1]; + // Broadcoasting allowed... 
can't check shapes easily + let (inp1, inp2) = broadcast(inp1, inp2); + + let zero = constants.get(&0).unwrap().as_ref(); + + let sq_diff_chip = SquaredDiffGadgetChip::::construct(gadget_config.clone()); + let inp1_vec = inp1.iter().map(|x| x.as_ref()).collect::>(); + let inp2_vec = inp2.iter().map(|x| x.as_ref()).collect::>(); + let vec_inputs = vec![inp1_vec, inp2_vec]; + let tmp_constants = vec![zero]; + let out = sq_diff_chip.forward( + layouter.namespace(|| "sq diff chip"), + &vec_inputs, + &tmp_constants, + )?; + + let var_div_chip = VarDivRoundChip::::construct(gadget_config.clone()); + let div = constants + .get(&(gadget_config.scale_factor as i64)) + .unwrap() + .as_ref(); + + let single_inputs = vec![zero, div]; + let out = out.iter().map(|x| x).collect::>(); + let out = var_div_chip.forward( + layouter.namespace(|| "sq diff div"), + &vec![out], + &single_inputs, + )?; + + let out = out.into_iter().map(|x| Rc::new(x)).collect::>(); + let out = Array::from_shape_vec(IxDyn(inp1.shape()), out).unwrap(); + + Ok(vec![out]) + } +} + +impl GadgetConsumer for SquaredDiffChip { + fn used_gadgets(&self, _layer_params: Vec) -> Vec { + vec![ + GadgetType::SquaredDiff, + GadgetType::VarDivRound, + GadgetType::InputLookup, + ] + } +} diff --git a/mnist_zkml/src/layers/tanh.rs b/mnist_zkml/src/layers/tanh.rs new file mode 100644 index 0000000..2d44365 --- /dev/null +++ b/mnist_zkml/src/layers/tanh.rs @@ -0,0 +1,45 @@ +use std::{collections::HashMap, rc::Rc, vec}; + +use halo2_proofs::{circuit::Layouter, halo2curves::ff::PrimeField, plonk::Error}; +use ndarray::{Array, IxDyn}; + +use crate::gadgets::{ + gadget::{Gadget, GadgetConfig, GadgetType}, + nonlinear::tanh::TanhGadgetChip, +}; + +use super::layer::{AssignedTensor, CellRc, GadgetConsumer, Layer, LayerConfig}; + +#[derive(Clone, Debug)] +pub struct TanhChip {} + +impl Layer for TanhChip { + fn forward( + &self, + mut layouter: impl Layouter, + tensors: &Vec>, + constants: &HashMap>, + gadget_config: Rc, + 
_layer_config: &LayerConfig, + ) -> Result>, Error> { + let inp = &tensors[0]; + let inp_vec = inp.iter().map(|x| x.as_ref()).collect::>(); + let zero = constants.get(&0).unwrap().as_ref(); + + let tanh_chip = TanhGadgetChip::::construct(gadget_config.clone()); + let vec_inps = vec![inp_vec]; + let constants = vec![zero]; + let out = tanh_chip.forward(layouter.namespace(|| "tanh chip"), &vec_inps, &constants)?; + + let out = out.into_iter().map(|x| Rc::new(x)).collect::>(); + let out = Array::from_shape_vec(IxDyn(inp.shape()), out).unwrap(); + + Ok(vec![out]) + } +} + +impl GadgetConsumer for TanhChip { + fn used_gadgets(&self, _layer_params: Vec) -> Vec { + vec![GadgetType::Tanh, GadgetType::InputLookup] + } +} diff --git a/mnist_zkml/src/layers/update.rs b/mnist_zkml/src/layers/update.rs new file mode 100644 index 0000000..73ec0ab --- /dev/null +++ b/mnist_zkml/src/layers/update.rs @@ -0,0 +1,51 @@ +use std::{collections::HashMap, rc::Rc, vec}; + +use halo2_proofs::{circuit::Layouter, halo2curves::ff::PrimeField, plonk::Error}; +use ndarray::{Array, IxDyn}; + +use crate::gadgets::{ + gadget::{Gadget, GadgetConfig, GadgetType}, + update::UpdateGadgetChip, +}; + +use super::layer::{AssignedTensor, CellRc, GadgetConsumer, Layer, LayerConfig}; + +#[derive(Clone, Debug)] +pub struct UpdateChip {} + +impl Layer for UpdateChip { + fn forward( + &self, + mut layouter: impl Layouter, + tensors: &Vec>, + constants: &HashMap>, + gadget_config: Rc, + _layer_config: &LayerConfig, + ) -> Result>, Error> { + let w = &tensors[0]; + let dw = &tensors[1]; + + let zero = constants.get(&0).unwrap().as_ref(); + let update_chip = UpdateGadgetChip::::construct((*gadget_config).clone()); + + let flattened_w = w.into_iter().map(|x| (**x).clone()).collect::>(); + let flattened_dw = dw.into_iter().map(|x| (**x).clone()).collect::>(); + let flattened_w_ref = flattened_w.iter().collect::>(); + let flattened_dw_ref = flattened_dw.iter().collect::>(); + + let vec_inps = vec![flattened_w_ref, 
flattened_dw_ref]; + let constants = vec![zero]; + let out = update_chip.forward(layouter.namespace(|| "update chip"), &vec_inps, &constants)?; + + let out = out.into_iter().map(|x| Rc::new(x)).collect::>(); + let out = Array::from_shape_vec(IxDyn(w.shape()), out).unwrap(); + + Ok(vec![out]) + } +} + +impl GadgetConsumer for UpdateChip { + fn used_gadgets(&self, _layer_params: Vec) -> Vec { + vec![GadgetType::Update] + } +} diff --git a/mnist_zkml/src/lib.rs b/mnist_zkml/src/lib.rs new file mode 100644 index 0000000..4ad95f1 --- /dev/null +++ b/mnist_zkml/src/lib.rs @@ -0,0 +1,7 @@ +#![feature(int_roundings)] + +pub mod commitments; +pub mod gadgets; +pub mod layers; +pub mod model; +pub mod utils; diff --git a/mnist_zkml/src/model.rs b/mnist_zkml/src/model.rs new file mode 100644 index 0000000..d88b59e --- /dev/null +++ b/mnist_zkml/src/model.rs @@ -0,0 +1,836 @@ +use std::{ + collections::{BTreeMap, BTreeSet, HashMap}, + marker::PhantomData, + rc::Rc, + sync::{Arc, Mutex}, +}; + +use halo2_proofs::{ + circuit::{Layouter, SimpleFloorPlanner, Value}, + halo2curves::ff::{FromUniformBytes, PrimeField}, + plonk::{Advice, Circuit, Column, ConstraintSystem, Error, Instance}, +}; +use lazy_static::lazy_static; +use ndarray::{Array, IxDyn}; +use num_bigint::BigUint; + +use crate::{ + commitments::{ + commit::Commit, + packer::PackerChip, + poseidon_commit::{PoseidonCommitChip, L, RATE, WIDTH}, + }, + gadgets::{ + add_pairs::AddPairsChip, + adder::AdderChip, + bias_div_round_relu6::BiasDivRoundRelu6Chip, + dot_prod::DotProductChip, + gadget::{Gadget, GadgetConfig, GadgetType}, + input_lookup::InputLookupChip, + max::MaxChip, + mul_pairs::MulPairsChip, + nonlinear::{exp::ExpGadgetChip, pow::PowGadgetChip, relu::ReluChip, tanh::TanhGadgetChip}, + nonlinear::{logistic::LogisticGadgetChip, rsqrt::RsqrtGadgetChip, sqrt::SqrtGadgetChip}, + sqrt_big::SqrtBigChip, + square::SquareGadgetChip, + squared_diff::SquaredDiffGadgetChip, + sub_pairs::SubPairsChip, + 
update::UpdateGadgetChip, + var_div::VarDivRoundChip, + var_div_big::VarDivRoundBigChip, + var_div_big3::VarDivRoundBig3Chip, + }, + layers::{ + arithmetic::{add::AddChip, div_var::DivVarChip, mul::MulChip, sub::SubChip}, + avg_pool_2d::AvgPool2DChip, + batch_mat_mul::BatchMatMulChip, + conv2d::Conv2DChip, + dag::{DAGLayerChip, DAGLayerConfig}, + fully_connected::{FullyConnectedChip, FullyConnectedConfig}, + layer::{AssignedTensor, CellRc, GadgetConsumer, LayerConfig, LayerType}, + logistic::LogisticChip, + max_pool_2d::MaxPool2DChip, + mean::MeanChip, + noop::NoopChip, + pow::PowChip, + rsqrt::RsqrtChip, + shape::{ + broadcast::BroadcastChip, concatenation::ConcatenationChip, mask_neg_inf::MaskNegInfChip, + pack::PackChip, pad::PadChip, permute::PermuteChip, reshape::ReshapeChip, + resize_nn::ResizeNNChip, rotate::RotateChip, slice::SliceChip, split::SplitChip, + transpose::TransposeChip, + }, + softmax::SoftmaxChip, + sqrt::SqrtChip, + square::SquareChip, + squared_diff::SquaredDiffChip, + tanh::TanhChip, + update::UpdateChip, + }, + utils::{ + helpers::{convert_to_bigint, RAND_START_IDX}, + loader::{load_model_msgpack, ModelMsgpack}, + }, +}; + +lazy_static! 
{ + pub static ref GADGET_CONFIG: Mutex = Mutex::new(GadgetConfig::default()); + pub static ref PUBLIC_VALS: Mutex> = Mutex::new(vec![]); +} + +#[derive(Clone, Debug, Default)] +pub struct ModelCircuit { + pub used_gadgets: Arc>, + pub dag_config: DAGLayerConfig, + pub tensors: BTreeMap>, + pub commit_before: Vec>, + pub commit_after: Vec>, + pub k: usize, + pub bits_per_elem: usize, + pub inp_idxes: Vec, + pub num_random: i64, +} + +#[derive(Clone, Debug)] +pub struct ModelConfig> { + pub gadget_config: Rc, + pub public_col: Column, + pub hasher: Option>, + pub _marker: PhantomData, +} + +impl> ModelCircuit { + pub fn assign_tensors_map( + &self, + mut layouter: impl Layouter, + columns: &Vec>, + tensors: &BTreeMap>, + ) -> Result>, Error> { + let tensors = layouter.assign_region( + || "asssignment", + |mut region| { + let mut cell_idx = 0; + let mut assigned_tensors = BTreeMap::new(); + + for (tensor_idx, tensor) in tensors.iter() { + let mut flat = vec![]; + for val in tensor.iter() { + let row_idx = cell_idx / columns.len(); + let col_idx = cell_idx % columns.len(); + let cell = region + .assign_advice( + || "assignment", + columns[col_idx], + row_idx, + || Value::known(*val), + ) + .unwrap(); + flat.push(Rc::new(cell)); + cell_idx += 1; + } + let tensor = Array::from_shape_vec(tensor.shape(), flat).unwrap(); + assigned_tensors.insert(*tensor_idx, tensor); + } + + Ok(assigned_tensors) + }, + )?; + + Ok(tensors) + } + + pub fn tensor_map_to_vec( + &self, + tensor_map: &BTreeMap, IxDyn>>, + ) -> Result>, Error> { + let smallest_tensor = tensor_map + .iter() + .min_by_key(|(_, tensor)| tensor.len()) + .unwrap() + .1; + let max_tensor_key = tensor_map + .iter() + .max_by_key(|(key, _)| *key) + .unwrap() + .0 + .clone(); + let mut tensors = vec![]; + for i in 0..max_tensor_key + 1 { + let tensor = tensor_map.get(&i).unwrap_or(smallest_tensor); + tensors.push(tensor.clone()); + } + + Ok(tensors) + } + + pub fn assign_tensors_vec( + &self, + mut layouter: impl 
Layouter, + columns: &Vec>, + tensors: &BTreeMap>, + ) -> Result>, Error> { + let tensor_map = self + .assign_tensors_map( + layouter.namespace(|| "assign_tensors_map"), + columns, + tensors, + ) + .unwrap(); + self.tensor_map_to_vec(&tensor_map) + } + + pub fn assign_constants( + &self, + mut layouter: impl Layouter, + gadget_config: Rc, + ) -> Result>, Error> { + let sf = gadget_config.scale_factor; + let min_val = gadget_config.min_val; + let max_val = gadget_config.max_val; + + let constants = layouter.assign_region( + || "constants", + |mut region| { + let mut constants: HashMap> = HashMap::new(); + + let vals = vec![0 as i64, 1, sf as i64, min_val, max_val]; + let shift_val_i64 = -min_val * 2; // FIXME + let shift_val_f = F::from(shift_val_i64 as u64); + for (i, val) in vals.iter().enumerate() { + let cell = region.assign_fixed( + || format!("constant_{}", i), + gadget_config.fixed_columns[0], + i, + || Value::known(F::from((val + shift_val_i64) as u64) - shift_val_f), + )?; + constants.insert(*val, Rc::new(cell)); + } + + // TODO: I've made some very bad life decisions + // TOOD: this needs to be a random oracle + let r_base = F::from(0x123456789abcdef); + let mut r = r_base.clone(); + for i in 0..self.num_random { + let rand = region.assign_fixed( + || format!("rand_{}", i), + gadget_config.fixed_columns[0], + constants.len(), + || Value::known(r), + )?; + r = r * r_base; + constants.insert(RAND_START_IDX + (i as i64), Rc::new(rand)); + } + + Ok(constants) + }, + )?; + Ok(constants) + } + + // TODO: for some horrifying reason, assigning to fixed columns causes everything to blow up + // Currently get around this by assigning to advice columns + // This is secure because of the equality checks but EXTREMELY STUPID + pub fn assign_constants2( + &self, + mut layouter: impl Layouter, + gadget_config: Rc, + fixed_constants: &HashMap>, + ) -> Result>, Error> { + let sf = gadget_config.scale_factor; + let min_val = gadget_config.min_val; + let max_val = 
gadget_config.max_val; + + let constants = layouter.assign_region( + || "constants", + |mut region| { + let mut constants: HashMap> = HashMap::new(); + + let vals = vec![0 as i64, 1, sf as i64, min_val, max_val]; + let shift_val_i64 = -min_val * 2; // FIXME + let shift_val_f = F::from(shift_val_i64 as u64); + for (i, val) in vals.iter().enumerate() { + let assignment_idx = i as usize; + let row_idx = assignment_idx / gadget_config.columns.len(); + let col_idx = assignment_idx % gadget_config.columns.len(); + let cell = region.assign_advice( + || format!("constant_{}", i), + gadget_config.columns[col_idx], + row_idx, + || Value::known(F::from((val + shift_val_i64) as u64) - shift_val_f), + )?; + constants.insert(*val, Rc::new(cell)); + } + + // TODO: I've made some very bad life decisions + // TOOD: this needs to be a random oracle + let r_base = F::from(0x123456789abcdef); + let mut r = r_base.clone(); + for i in 0..self.num_random { + let assignment_idx = constants.len(); + let row_idx = assignment_idx / gadget_config.columns.len(); + let col_idx = assignment_idx % gadget_config.columns.len(); + let rand = region.assign_advice( + || format!("rand_{}", i), + gadget_config.columns[col_idx], + row_idx, + || Value::known(r), + )?; + r = r * r_base; + constants.insert(RAND_START_IDX + (i as i64), Rc::new(rand)); + } + + for (k, v) in fixed_constants.iter() { + let v2 = constants.get(k).unwrap(); + region.constrain_equal(v.cell(), v2.cell()).unwrap(); + } + Ok(constants) + }, + )?; + Ok(constants) + } + + pub fn generate_from_file(config_file: &str, inp_file: &str) -> ModelCircuit { + let config = load_model_msgpack(config_file, inp_file); + Self::generate_from_msgpack(config, true) + } + + pub fn generate_from_msgpack(config: ModelMsgpack, panic_empty_tensor: bool) -> ModelCircuit { + let to_field = |x: i64| { + let bias = 1 << 31; + let x_pos = x + bias; + F::from(x_pos as u64) - F::from(bias as u64) + }; + + let match_layer = |x: &str| match x { + "AveragePool2D" => 
LayerType::AvgPool2D, + "Add" => LayerType::Add, + "BatchMatMul" => LayerType::BatchMatMul, + "Broadcast" => LayerType::Broadcast, + "Concatenation" => LayerType::Concatenation, + "Conv2D" => LayerType::Conv2D, + "Div" => LayerType::DivFixed, // TODO: rename to DivFixed + "DivVar" => LayerType::DivVar, + "FullyConnected" => LayerType::FullyConnected, + "Logistic" => LayerType::Logistic, + "MaskNegInf" => LayerType::MaskNegInf, + "MaxPool2D" => LayerType::MaxPool2D, + "Mean" => LayerType::Mean, + "Mul" => LayerType::Mul, + "Noop" => LayerType::Noop, + "Pack" => LayerType::Pack, + "Pad" => LayerType::Pad, + "Pow" => LayerType::Pow, + "Permute" => LayerType::Permute, + "Reshape" => LayerType::Reshape, + "ResizeNearestNeighbor" => LayerType::ResizeNN, + "Rotate" => LayerType::Rotate, + "Rsqrt" => LayerType::Rsqrt, + "Slice" => LayerType::Slice, + "Softmax" => LayerType::Softmax, + "Split" => LayerType::Split, + "Sqrt" => LayerType::Sqrt, + "Square" => LayerType::Square, + "SquaredDifference" => LayerType::SquaredDifference, + "Sub" => LayerType::Sub, + "Tanh" => LayerType::Tanh, + "Transpose" => LayerType::Transpose, + "Update" => LayerType::Update, + _ => panic!("unknown op: {}", x), + }; + + let mut tensors = BTreeMap::new(); + for flat in config.tensors { + let value_flat = flat.data.iter().map(|x| to_field(*x)).collect::>(); + let shape = flat.shape.iter().map(|x| *x as usize).collect::>(); + let num_el: usize = shape.iter().product(); + if panic_empty_tensor && num_el != value_flat.len() { + panic!("tensor shape and data length mismatch"); + } + if num_el == value_flat.len() { + let tensor = Array::from_shape_vec(IxDyn(&shape), value_flat).unwrap(); + tensors.insert(flat.idx, tensor); + } else { + // Do nothing here since we're loading the config + }; + } + + let i64_to_usize = |x: &Vec| x.iter().map(|x| *x as usize).collect::>(); + + let mut used_gadgets = BTreeSet::new(); + + let dag_config = { + let ops = config + .layers + .iter() + .map(|layer| { + let 
layer_type = match_layer(&layer.layer_type); + let layer_gadgets = match layer_type { + LayerType::Add => Box::new(AddChip {}) as Box, + LayerType::AvgPool2D => Box::new(AvgPool2DChip {}) as Box, + LayerType::BatchMatMul => Box::new(BatchMatMulChip {}) as Box, + LayerType::Broadcast => Box::new(BroadcastChip {}) as Box, + LayerType::Concatenation => Box::new(ConcatenationChip {}) as Box, + LayerType::DivFixed => Box::new(ConcatenationChip {}) as Box, + LayerType::DivVar => Box::new(DivVarChip {}) as Box, + LayerType::Conv2D => Box::new(Conv2DChip { + config: LayerConfig::default(), + _marker: PhantomData::, + }) as Box, + LayerType::FullyConnected => Box::new(FullyConnectedChip { + config: FullyConnectedConfig { normalize: true }, + _marker: PhantomData::, + }) as Box, + LayerType::Logistic => Box::new(LogisticChip {}) as Box, + LayerType::MaskNegInf => Box::new(MaskNegInfChip {}) as Box, + LayerType::MaxPool2D => Box::new(MaxPool2DChip { + marker: PhantomData::, + }) as Box, + LayerType::Mean => Box::new(MeanChip {}) as Box, + LayerType::Mul => Box::new(MulChip {}) as Box, + LayerType::Noop => Box::new(NoopChip {}) as Box, + LayerType::Pack => Box::new(PackChip {}) as Box, + LayerType::Pad => Box::new(PadChip {}) as Box, + LayerType::Pow => Box::new(PowChip {}) as Box, + LayerType::Permute => Box::new(PermuteChip {}) as Box, + LayerType::Reshape => Box::new(ReshapeChip {}) as Box, + LayerType::ResizeNN => Box::new(ResizeNNChip {}) as Box, + LayerType::Rotate => Box::new(RotateChip {}) as Box, + LayerType::Rsqrt => Box::new(RsqrtChip {}) as Box, + LayerType::Slice => Box::new(SliceChip {}) as Box, + LayerType::Softmax => Box::new(SoftmaxChip {}) as Box, + LayerType::Split => Box::new(SplitChip {}) as Box, + LayerType::Sqrt => Box::new(SqrtChip {}) as Box, + LayerType::Square => Box::new(SquareChip {}) as Box, + LayerType::SquaredDifference => Box::new(SquaredDiffChip {}) as Box, + LayerType::Sub => Box::new(SubChip {}) as Box, + LayerType::Tanh => Box::new(TanhChip 
{}) as Box, + LayerType::Transpose => Box::new(TransposeChip {}) as Box, + LayerType::Update => Box::new(UpdateChip {}) as Box, + } + .used_gadgets(layer.params.clone()); + for gadget in layer_gadgets { + used_gadgets.insert(gadget); + } + + LayerConfig { + layer_type, + layer_params: layer.params.clone(), + inp_shapes: layer.inp_shapes.iter().map(|x| i64_to_usize(x)).collect(), + out_shapes: layer.out_shapes.iter().map(|x| i64_to_usize(x)).collect(), + mask: layer.mask.clone(), + } + }) + .collect::>(); + let inp_idxes = config + .layers + .iter() + .map(|layer| i64_to_usize(&layer.inp_idxes)) + .collect::>(); + let out_idxes = config + .layers + .iter() + .map(|layer| i64_to_usize(&layer.out_idxes)) + .collect::>(); + let final_out_idxes = config + .out_idxes + .iter() + .map(|x| *x as usize) + .collect::>(); + DAGLayerConfig { + inp_idxes, + out_idxes, + ops, + final_out_idxes, + } + }; + + // The input lookup is always used + used_gadgets.insert(GadgetType::InputLookup); + let used_gadgets = Arc::new(used_gadgets); + let gadget = &GADGET_CONFIG; + let cloned_gadget = gadget.lock().unwrap().clone(); + *gadget.lock().unwrap() = GadgetConfig { + scale_factor: config.global_sf as u64, + shift_min_val: -(config.global_sf * config.global_sf * (1 << 17)), + div_outp_min_val: -(1 << (config.k - 1)), + min_val: -(1 << (config.k - 1)), + max_val: (1 << (config.k - 1)) - 10, + k: config.k as usize, + num_rows: (1 << config.k) - 10 + 1, + num_cols: config.num_cols as usize, + used_gadgets: used_gadgets.clone(), + commit_before: config.commit_before.clone().unwrap_or(vec![]), + commit_after: config.commit_after.clone().unwrap_or(vec![]), + use_selectors: config.use_selectors.unwrap_or(true), + num_bits_per_elem: config.bits_per_elem.unwrap_or(config.k), + ..cloned_gadget + }; + + ModelCircuit { + tensors, + dag_config, + used_gadgets, + k: config.k as usize, + bits_per_elem: config.bits_per_elem.unwrap_or(config.k) as usize, + inp_idxes: config.inp_idxes.clone(), + 
commit_after: config.commit_after.unwrap_or(vec![]), + commit_before: config.commit_before.unwrap_or(vec![]), + num_random: config.num_random.unwrap_or(0), + } + } + + pub fn assign_and_commit( + &self, + mut layouter: impl Layouter, + constants: &HashMap>, + config: &ModelConfig, + tensors: &BTreeMap>, + ) -> (BTreeMap>, CellRc) { + let num_bits = self.bits_per_elem; + let packer_config = PackerChip::::construct(num_bits, config.gadget_config.as_ref()); + let packer_chip = PackerChip:: { + config: packer_config, + }; + let (tensor_map, packed) = packer_chip + .assign_and_pack( + layouter.namespace(|| "packer"), + config.gadget_config.clone(), + constants, + tensors, + ) + .unwrap(); + + let zero = constants.get(&0).unwrap().clone(); + let commit_chip = config.hasher.clone().unwrap(); + + let commitments = commit_chip + .commit( + layouter.namespace(|| "commit"), + config.gadget_config.clone(), + constants, + &packed, + zero.clone(), + ) + .unwrap(); + assert_eq!(commitments.len(), 1); + + (tensor_map, commitments[0].clone()) + } + + pub fn copy_and_commit( + &self, + mut layouter: impl Layouter, + constants: &HashMap>, + config: &ModelConfig, + tensors: &BTreeMap>, + ) -> CellRc { + let num_bits = self.bits_per_elem; + let packer_config = PackerChip::::construct(num_bits, config.gadget_config.as_ref()); + let packer_chip = PackerChip:: { + config: packer_config, + }; + let packed = packer_chip + .copy_and_pack( + layouter.namespace(|| "packer"), + config.gadget_config.clone(), + constants, + tensors, + ) + .unwrap(); + + let zero = constants.get(&0).unwrap().clone(); + let commit_chip = config.hasher.clone().unwrap(); + + let commitments = commit_chip + .commit( + layouter.namespace(|| "commit"), + config.gadget_config.clone(), + constants, + &packed, + zero.clone(), + ) + .unwrap(); + assert_eq!(commitments.len(), 1); + + commitments[0].clone() + } +} + +impl> Circuit for ModelCircuit { + type Config = ModelConfig; + type FloorPlanner = SimpleFloorPlanner; + type 
Params = (); + + fn without_witnesses(&self) -> Self { + todo!() + } + + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + let mut gadget_config = crate::model::GADGET_CONFIG.lock().unwrap().clone(); + let columns = (0..gadget_config.num_cols) + .map(|_| meta.advice_column()) + .collect::>(); + for col in columns.iter() { + meta.enable_equality(*col); + } + gadget_config.columns = columns; + + let public_col = meta.instance_column(); + meta.enable_equality(public_col); + + gadget_config.fixed_columns = vec![meta.fixed_column()]; + meta.enable_equality(gadget_config.fixed_columns[0]); + + // The input lookup is always loaded + gadget_config = InputLookupChip::::configure(meta, gadget_config); + + let used_gadgets = gadget_config.used_gadgets.clone(); + for gadget_type in used_gadgets.iter() { + gadget_config = match gadget_type { + GadgetType::AddPairs => AddPairsChip::::configure(meta, gadget_config), + GadgetType::Adder => AdderChip::::configure(meta, gadget_config), + GadgetType::BiasDivRoundRelu6 => BiasDivRoundRelu6Chip::::configure(meta, gadget_config), + GadgetType::BiasDivFloorRelu6 => panic!(), + GadgetType::DotProduct => DotProductChip::::configure(meta, gadget_config), + GadgetType::Exp => ExpGadgetChip::::configure(meta, gadget_config), + GadgetType::Logistic => LogisticGadgetChip::::configure(meta, gadget_config), + GadgetType::Max => MaxChip::::configure(meta, gadget_config), + GadgetType::MulPairs => MulPairsChip::::configure(meta, gadget_config), + GadgetType::Pow => PowGadgetChip::::configure(meta, gadget_config), + GadgetType::Relu => ReluChip::::configure(meta, gadget_config), + GadgetType::Rsqrt => RsqrtGadgetChip::::configure(meta, gadget_config), + GadgetType::Sqrt => SqrtGadgetChip::::configure(meta, gadget_config), + GadgetType::SqrtBig => SqrtBigChip::::configure(meta, gadget_config), + GadgetType::Square => SquareGadgetChip::::configure(meta, gadget_config), + GadgetType::SquaredDiff => SquaredDiffGadgetChip::::configure(meta, 
gadget_config), + GadgetType::SubPairs => SubPairsChip::::configure(meta, gadget_config), + GadgetType::Tanh => TanhGadgetChip::::configure(meta, gadget_config), + GadgetType::VarDivRound => VarDivRoundChip::::configure(meta, gadget_config), + GadgetType::VarDivRoundBig => VarDivRoundBigChip::::configure(meta, gadget_config), + GadgetType::VarDivRoundBig3 => VarDivRoundBig3Chip::::configure(meta, gadget_config), + GadgetType::InputLookup => gadget_config, // This is always loaded + GadgetType::Update => UpdateGadgetChip::::configure(meta, gadget_config), + GadgetType::Packer => panic!(), + }; + } + + let hasher = if gadget_config.commit_before.len() + gadget_config.commit_after.len() > 0 { + let packer_config = + PackerChip::::construct(gadget_config.num_bits_per_elem as usize, &gadget_config); + gadget_config = PackerChip::::configure(meta, packer_config, gadget_config); + + // TODO + let input = gadget_config.columns[0..L].try_into().unwrap(); + let state = gadget_config.columns[L..L + WIDTH].try_into().unwrap(); + let partial_sbox = gadget_config.columns[L + WIDTH].into(); + Some(PoseidonCommitChip::::configure( + meta, + input, + state, + partial_sbox, + )) + } else { + None + }; + + ModelConfig { + gadget_config: gadget_config.into(), + public_col, + hasher, + _marker: PhantomData, + } + } + + fn synthesize(&self, config: Self::Config, mut layouter: impl Layouter) -> Result<(), Error> { + // Assign tables + let gadget_rc: Rc = config.gadget_config.clone().into(); + for gadget in self.used_gadgets.iter() { + match gadget { + GadgetType::AddPairs => { + let chip = AddPairsChip::::construct(gadget_rc.clone()); + chip.load_lookups(layouter.namespace(|| "add pairs lookup"))?; + } + GadgetType::Adder => { + let chip = AdderChip::::construct(gadget_rc.clone()); + chip.load_lookups(layouter.namespace(|| "adder lookup"))?; + } + GadgetType::BiasDivRoundRelu6 => { + let chip = BiasDivRoundRelu6Chip::::construct(gadget_rc.clone()); + 
chip.load_lookups(layouter.namespace(|| "bias div round relu6 lookup"))?; + } + GadgetType::DotProduct => { + let chip = DotProductChip::::construct(gadget_rc.clone()); + chip.load_lookups(layouter.namespace(|| "dot product lookup"))?; + } + GadgetType::VarDivRound => { + let chip = VarDivRoundChip::::construct(gadget_rc.clone()); + chip.load_lookups(layouter.namespace(|| "var div lookup"))?; + } + GadgetType::Pow => { + let chip = PowGadgetChip::::construct(gadget_rc.clone()); + chip.load_lookups(layouter.namespace(|| "pow lookup"))?; + } + GadgetType::Relu => { + let chip = ReluChip::::construct(gadget_rc.clone()); + chip.load_lookups(layouter.namespace(|| "relu lookup"))?; + } + GadgetType::Rsqrt => { + let chip = RsqrtGadgetChip::::construct(gadget_rc.clone()); + chip.load_lookups(layouter.namespace(|| "rsqrt lookup"))?; + } + GadgetType::Sqrt => { + let chip = SqrtGadgetChip::::construct(gadget_rc.clone()); + chip.load_lookups(layouter.namespace(|| "sqrt lookup"))?; + } + GadgetType::Tanh => { + let chip = TanhGadgetChip::::construct(gadget_rc.clone()); + chip.load_lookups(layouter.namespace(|| "tanh lookup"))?; + } + GadgetType::Exp => { + let chip = ExpGadgetChip::::construct(gadget_rc.clone()); + chip.load_lookups(layouter.namespace(|| "exp lookup"))?; + } + GadgetType::Logistic => { + let chip = LogisticGadgetChip::::construct(gadget_rc.clone()); + chip.load_lookups(layouter.namespace(|| "logistic lookup"))?; + } + GadgetType::InputLookup => { + let chip = InputLookupChip::::construct(gadget_rc.clone()); + chip.load_lookups(layouter.namespace(|| "input lookup"))?; + } + GadgetType::VarDivRoundBig => {} + GadgetType::VarDivRoundBig3 => {} + GadgetType::Max => {} + GadgetType::MulPairs => {} + GadgetType::SqrtBig => {} + GadgetType::Square => {} + GadgetType::SquaredDiff => {} + GadgetType::SubPairs => {} + GadgetType::Update => {} + _ => panic!("unsupported gadget {:?}", gadget), + } + } + + // Assign weights and constants + let constants_base = self + 
.assign_constants( + layouter.namespace(|| "constants"), + config.gadget_config.clone(), + ) + .unwrap(); + // Some halo2 cancer + let constants = self + .assign_constants2( + layouter.namespace(|| "constants 2"), + config.gadget_config.clone(), + &constants_base, + ) + .unwrap(); + + let mut commitments = vec![]; + let tensors = if self.commit_before.len() > 0 { + // Commit to the tensors before the DAG + let mut tensor_map = BTreeMap::new(); + let mut ignore_idxes: Vec = vec![]; + for commit_idxes in self.commit_before.iter() { + let to_commit = BTreeMap::from_iter( + commit_idxes + .iter() + .map(|idx| (*idx, self.tensors.get(idx).unwrap().clone())), + ); + let (mut committed_tensors, commitment) = self.assign_and_commit( + layouter.namespace(|| "commit"), + &constants, + &config, + &to_commit, + ); + commitments.push(commitment); + tensor_map.append(&mut committed_tensors); + ignore_idxes.extend(commit_idxes.iter()); + } + + // Assign the remainder of the tensors + let mut assign_map = BTreeMap::new(); + for (idx, tensor) in self.tensors.iter() { + if ignore_idxes.contains(idx) { + continue; + } + assign_map.insert(*idx, tensor.clone()); + } + let mut remainder_tensor_map = self + .assign_tensors_map( + layouter.namespace(|| "assignment"), + &config.gadget_config.columns, + &assign_map, + ) + .unwrap(); + + // Merge the two maps + tensor_map.append(&mut remainder_tensor_map); + + // Return the tensors + self.tensor_map_to_vec(&tensor_map).unwrap() + } else { + self + .assign_tensors_vec( + layouter.namespace(|| "assignment"), + &config.gadget_config.columns, + &self.tensors, + ) + .unwrap() + }; + + // Perform the dag + let dag_chip = DAGLayerChip::::construct(self.dag_config.clone()); + let (final_tensor_map, result) = dag_chip.forward( + layouter.namespace(|| "dag"), + &tensors, + &constants, + config.gadget_config.clone(), + &LayerConfig::default(), + )?; + + if self.commit_after.len() > 0 { + for commit_idxes in self.commit_after.iter() { + let to_commit = 
BTreeMap::from_iter(commit_idxes.iter().map(|idx| { + ( + *idx, + final_tensor_map.get(&(*idx as usize)).unwrap().clone(), + ) + })); + let commitment = self.copy_and_commit( + layouter.namespace(|| "commit"), + &constants, + &config, + &to_commit, + ); + commitments.push(commitment); + } + } + + let mut pub_layouter = layouter.namespace(|| "public"); + let mut total_idx = 0; + let mut new_public_vals = vec![]; + for cell in commitments.iter() { + pub_layouter + .constrain_instance(cell.as_ref().cell(), config.public_col, total_idx) + .unwrap(); + let val = convert_to_bigint(cell.value().map(|x| x.to_owned())); + new_public_vals.push(val); + total_idx += 1; + } + for tensor in result { + for cell in tensor.iter() { + pub_layouter + .constrain_instance(cell.as_ref().cell(), config.public_col, total_idx) + .unwrap(); + let val = convert_to_bigint(cell.value().map(|x| x.to_owned())); + new_public_vals.push(val); + total_idx += 1; + } + } + *PUBLIC_VALS.lock().unwrap() = new_public_vals; + + Ok(()) + } +} diff --git a/mnist_zkml/src/utils.rs b/mnist_zkml/src/utils.rs new file mode 100644 index 0000000..08ee186 --- /dev/null +++ b/mnist_zkml/src/utils.rs @@ -0,0 +1,4 @@ +pub mod helpers; +pub mod loader; +pub mod proving_ipa; +pub mod proving_kzg; diff --git a/mnist_zkml/src/utils/helpers.rs b/mnist_zkml/src/utils/helpers.rs new file mode 100644 index 0000000..c1aec93 --- /dev/null +++ b/mnist_zkml/src/utils/helpers.rs @@ -0,0 +1,140 @@ +use halo2_proofs::{ + circuit::{AssignedCell, Value}, + halo2curves::ff::PrimeField, +}; +use ndarray::{Array, IxDyn}; +use num_bigint::BigUint; + +use crate::{gadgets::gadget::convert_to_u128, model::PUBLIC_VALS}; + +// TODO: this is very bad +pub const RAND_START_IDX: i64 = i64::MIN; +pub const NUM_RANDOMS: i64 = 20001; + +// Conversion / printing functions +pub fn convert_to_bigint(x: Value) -> BigUint { + let mut big = Default::default(); + x.map(|x| { + big = BigUint::from_bytes_le(x.to_repr().as_ref()); + }); + big +} + +pub fn 
convert_pos_int(x: Value) -> i128 { + let bias = 1 << 60; + let x_pos = x + Value::known(F::from(bias as u64)); + let mut outp: i128 = 0; + x_pos.map(|x| { + let x_pos = convert_to_u128(&x); + let tmp = x_pos as i128 - bias; + outp = tmp; + }); + return outp; +} + +pub fn print_pos_int(prefix: &str, x: Value, scale_factor: u64) { + let tmp = convert_pos_int(x); + let tmp_float = tmp as f64 / scale_factor as f64; + println!("{} x: {} ({})", prefix, tmp, tmp_float); +} + +pub fn print_assigned_arr( + prefix: &str, + arr: &Vec<&AssignedCell>, + scale_factor: u64, +) { + for (idx, x) in arr.iter().enumerate() { + print_pos_int( + &format!("{}[{}]", prefix, idx), + x.value().map(|x: &F| x.to_owned()), + scale_factor, + ); + } +} + +// Get the public values +pub fn get_public_values() -> Vec { + let mut public_vals = vec![]; + for val in PUBLIC_VALS.lock().unwrap().iter() { + let val = F::from_str_vartime(&val.to_str_radix(10)); + public_vals.push(val.unwrap()); + } + public_vals +} + +// Broadcast +fn shape_dominates(s1: &[usize], s2: &[usize]) -> bool { + if s1.len() != s2.len() { + return false; + } + + for (x1, x2) in s1.iter().zip(s2.iter()) { + if x1 < x2 { + return false; + } + } + + true +} + +// Precondition: s1.len() < s2.len() +fn intermediate_shape(s1: &[usize], s2: &[usize]) -> Vec { + let mut res = vec![1; s2.len() - s1.len()]; + for s in s1.iter() { + res.push(*s); + } + res +} + +fn final_shape(s1: &[usize], s2: &[usize]) -> Vec { + let mut res = vec![]; + for (x1, x2) in s1.iter().zip(s2.iter()) { + res.push(std::cmp::max(*x1, *x2)); + } + res +} + +pub fn broadcast( + x1: &Array, + x2: &Array, +) -> (Array, Array) { + if x1.shape() == x2.shape() { + return (x1.clone(), x2.clone()); + } + + if x1.ndim() == x2.ndim() { + let s1 = x1.shape(); + let s2 = x2.shape(); + if shape_dominates(s1, s2) { + return (x1.clone(), x2.broadcast(s1).unwrap().into_owned()); + } else if shape_dominates(x2.shape(), x1.shape()) { + return 
(x1.broadcast(s2).unwrap().into_owned(), x2.clone()); + } + } + + let (tmp1, tmp2) = if x1.ndim() < x2.ndim() { + (x1, x2) + } else { + (x2, x1) + }; + + // tmp1.ndim() < tmp2.ndim() + let s1 = tmp1.shape(); + let s2 = tmp2.shape(); + let s = intermediate_shape(s1, s2); + let final_shape = final_shape(s2, s.as_slice()); + + let tmp1 = tmp1.broadcast(s.clone()).unwrap().into_owned(); + let tmp1 = tmp1.broadcast(final_shape.as_slice()).unwrap().into_owned(); + let tmp2 = tmp2.broadcast(final_shape.as_slice()).unwrap().into_owned(); + // println!("x1: {:?} x2: {:?}", x1.shape(), x2.shape()); + // println!("s1: {:?} s2: {:?} s: {:?}", s1, s2, s); + // println!("tmp1 shape: {:?}", tmp1.shape()); + // println!("tmp2 shape: {:?}", tmp2.shape()); + + if x1.ndim() < x2.ndim() { + return (tmp1, tmp2); + } else { + return (tmp2, tmp1); + } +} diff --git a/mnist_zkml/src/utils/loader.rs b/mnist_zkml/src/utils/loader.rs new file mode 100644 index 0000000..f3a4721 --- /dev/null +++ b/mnist_zkml/src/utils/loader.rs @@ -0,0 +1,77 @@ +use std::{fs::File, io::BufReader}; + +use serde_derive::{Deserialize, Serialize}; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct TensorMsgpack { + pub idx: i64, + pub shape: Vec, + pub data: Vec, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct LayerMsgpack { + pub layer_type: String, + pub params: Vec, + pub inp_idxes: Vec, + pub inp_shapes: Vec>, + pub out_idxes: Vec, + pub out_shapes: Vec>, + pub mask: Vec, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ModelMsgpack { + pub global_sf: i64, + pub k: i64, + pub num_cols: i64, + pub inp_idxes: Vec, + pub out_idxes: Vec, + pub tensors: Vec, + pub layers: Vec, + pub use_selectors: Option, + pub commit_before: Option>>, + pub commit_after: Option>>, + pub bits_per_elem: Option, // Specifically for packing for the commitments + pub num_random: Option, +} + +pub fn load_config_msgpack(config_path: &str) -> ModelMsgpack { + let model: ModelMsgpack = { + 
let file = File::open(config_path).unwrap(); + let mut reader = BufReader::new(file); + rmp_serde::from_read(&mut reader).unwrap() + }; + model +} + +pub fn load_model_msgpack(config_path: &str, inp_path: &str) -> ModelMsgpack { + let mut model = load_config_msgpack(config_path); + let inp: Vec = { + let file = File::open(inp_path).unwrap(); + let mut reader = BufReader::new(file); + rmp_serde::from_read(&mut reader).unwrap() + }; + for tensor in inp { + model.tensors.push(tensor); + } + + // Default to using selectors, commit if use_selectors is not specified + if model.use_selectors.is_none() { + model.use_selectors = Some(true) + }; + if model.commit_before.is_none() { + model.commit_before = Some(vec![]) + }; + if model.commit_after.is_none() { + model.commit_after = Some(vec![]) + }; + if model.bits_per_elem.is_none() { + model.bits_per_elem = Some(model.k) + }; + if model.num_random.is_none() { + model.num_random = Some(20001) + }; + + model +} diff --git a/mnist_zkml/src/utils/proving_ipa.rs b/mnist_zkml/src/utils/proving_ipa.rs new file mode 100644 index 0000000..c226efc --- /dev/null +++ b/mnist_zkml/src/utils/proving_ipa.rs @@ -0,0 +1,126 @@ +use std::{ + fs::File, + io::{BufReader, Write}, + path::Path, + time::Instant, +}; + +use halo2_proofs::{ + dev::MockProver, + halo2curves::pasta::{EqAffine, Fp}, + plonk::{create_proof, keygen_pk, keygen_vk, verify_proof}, + poly::{ + commitment::{Params, ParamsProver}, + ipa::{ + commitment::{IPACommitmentScheme, ParamsIPA}, + multiopen::ProverIPA, + strategy::SingleStrategy, + }, + VerificationStrategy, + }, + transcript::{ + Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer, TranscriptWriterBuffer, + }, +}; + +use crate::{model::ModelCircuit, utils::helpers::get_public_values}; + +pub fn get_ipa_params(params_dir: &str, degree: u32) -> ParamsIPA { + let path = format!("{}/{}.params", params_dir, degree); + let params_path = Path::new(&path); + if File::open(¶ms_path).is_err() { + let params: 
ParamsIPA = ParamsIPA::new(degree); + let mut buf = Vec::new(); + + params.write(&mut buf).expect("Failed to write params"); + let mut file = File::create(¶ms_path).expect("Failed to create params file"); + file + .write_all(&buf[..]) + .expect("Failed to write params to file"); + } + + let params_fs = File::open(¶ms_path).expect("couldn't load params"); + let params: ParamsIPA = + Params::read::<_>(&mut BufReader::new(params_fs)).expect("Failed to read params"); + params +} + +pub fn time_circuit_ipa(circuit: ModelCircuit) { + let rng = rand::thread_rng(); + let start = Instant::now(); + + let degree = circuit.k as u32; + let empty_circuit = circuit.clone(); + let proof_circuit = circuit; + + let params = get_ipa_params("./params_ipa", degree); + + let circuit_duration = start.elapsed(); + println!( + "Time elapsed in params construction: {:?}", + circuit_duration + ); + + let vk = keygen_vk(¶ms, &empty_circuit).unwrap(); + let vk_duration = start.elapsed(); + println!( + "Time elapsed in generating vkey: {:?}", + vk_duration - circuit_duration + ); + + let pk = keygen_pk(¶ms, vk, &empty_circuit).unwrap(); + let pk_duration = start.elapsed(); + println!( + "Time elapsed in generating pkey: {:?}", + pk_duration - vk_duration + ); + drop(empty_circuit); + + let fill_duration = start.elapsed(); + let _prover = MockProver::run(degree, &proof_circuit, vec![vec![]]).unwrap(); + let public_vals = get_public_values(); + println!( + "Time elapsed in filling circuit: {:?}", + fill_duration - pk_duration + ); + + let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); + create_proof::, ProverIPA, _, _, _, _>( + ¶ms, + &pk, + &[proof_circuit], + &[&[&public_vals]], + rng, + &mut transcript, + ) + .unwrap(); + let proof = transcript.finalize(); + let proof_duration = start.elapsed(); + println!("Proving time: {:?}", proof_duration - fill_duration); + + let proof_size = { + let mut folder = std::path::PathBuf::new(); + folder.push("proof"); + let mut fd = 
std::fs::File::create(folder.as_path()).unwrap(); + folder.pop(); + fd.write_all(&proof).unwrap(); + fd.metadata().unwrap().len() + }; + println!("Proof size: {} bytes", proof_size); + + let strategy = SingleStrategy::new(¶ms); + let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]); + assert!( + verify_proof( + ¶ms, + pk.get_vk(), + strategy, + &[&[&public_vals]], + &mut transcript + ) + .is_ok(), + "proof did not verify" + ); + let verify_duration = start.elapsed(); + println!("Verifying time: {:?}", verify_duration - proof_duration); +} diff --git a/mnist_zkml/src/utils/proving_kzg.rs b/mnist_zkml/src/utils/proving_kzg.rs new file mode 100644 index 0000000..dd44442 --- /dev/null +++ b/mnist_zkml/src/utils/proving_kzg.rs @@ -0,0 +1,206 @@ +use std::{ + fs::File, + io::{BufReader, Write}, + path::Path, + time::Instant, +}; + +use halo2_proofs::{ + dev::MockProver, + halo2curves::bn256::{Bn256, Fr, G1Affine}, + plonk::{create_proof, keygen_pk, keygen_vk, verify_proof, VerifyingKey}, + poly::{ + commitment::Params, + kzg::{ + commitment::{KZGCommitmentScheme, ParamsKZG}, + multiopen::{ProverSHPLONK, VerifierSHPLONK}, + strategy::SingleStrategy, + }, + }, + transcript::{ + Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer, TranscriptWriterBuffer, + }, + SerdeFormat, +}; + +use crate::{model::ModelCircuit, utils::helpers::get_public_values}; + +pub fn get_kzg_params(params_dir: &str, degree: u32) -> ParamsKZG { + let rng = rand::thread_rng(); + let path = format!("{}/{}.params", params_dir, degree); + let params_path = Path::new(&path); + if File::open(¶ms_path).is_err() { + let params = ParamsKZG::::setup(degree, rng); + let mut buf = Vec::new(); + + params.write(&mut buf).expect("Failed to write params"); + let mut file = File::create(¶ms_path).expect("Failed to create params file"); + file + .write_all(&buf[..]) + .expect("Failed to write params to file"); + } + + let mut params_fs = File::open(¶ms_path).expect("couldn't load 
params"); + let params = ParamsKZG::::read(&mut params_fs).expect("Failed to read params"); + params +} + +pub fn serialize(data: &Vec, path: &str) -> u64 { + let mut file = File::create(path).unwrap(); + file.write_all(data).unwrap(); + file.metadata().unwrap().len() +} + +pub fn verify_kzg( + params: &ParamsKZG, + vk: &VerifyingKey, + strategy: SingleStrategy, + public_vals: &Vec, + mut transcript: Blake2bRead<&[u8], G1Affine, Challenge255>, +) { + assert!( + verify_proof::< + KZGCommitmentScheme, + VerifierSHPLONK<'_, Bn256>, + Challenge255, + Blake2bRead<&[u8], G1Affine, Challenge255>, + halo2_proofs::poly::kzg::strategy::SingleStrategy<'_, Bn256>, + >(¶ms, &vk, strategy, &[&[&public_vals]], &mut transcript) + .is_ok(), + "proof did not verify" + ); +} + +pub fn time_circuit_kzg(circuit: ModelCircuit) { + let rng = rand::thread_rng(); + let start = Instant::now(); + + let degree = circuit.k as u32; + let params = get_kzg_params("./params_kzg", degree); + + let circuit_duration = start.elapsed(); + println!( + "Time elapsed in params construction: {:?}", + circuit_duration + ); + + let vk_circuit = circuit.clone(); + let vk = keygen_vk(¶ms, &vk_circuit).unwrap(); + drop(vk_circuit); + let vk_duration = start.elapsed(); + println!( + "Time elapsed in generating vkey: {:?}", + vk_duration - circuit_duration + ); + + let vkey_size = serialize(&vk.to_bytes(SerdeFormat::RawBytes), "vkey"); + println!("vkey size: {} bytes", vkey_size); + + let pk_circuit = circuit.clone(); + let pk = keygen_pk(¶ms, vk, &pk_circuit).unwrap(); + let pk_duration = start.elapsed(); + println!( + "Time elapsed in generating pkey: {:?}", + pk_duration - vk_duration + ); + drop(pk_circuit); + + let pkey_size = serialize(&pk.to_bytes(SerdeFormat::RawBytes), "pkey"); + println!("pkey size: {} bytes", pkey_size); + + let fill_duration = start.elapsed(); + let proof_circuit = circuit.clone(); + let _prover = MockProver::run(degree, &proof_circuit, vec![vec![]]).unwrap(); + let public_vals = 
get_public_values(); + println!( + "Time elapsed in filling circuit: {:?}", + fill_duration - pk_duration + ); + + // Convert public vals to serializable format + let public_vals_u8: Vec = public_vals + .iter() + .map(|v: &Fr| v.to_bytes().to_vec()) + .flatten() + .collect(); + let public_vals_u8_size = serialize(&public_vals_u8, "public_vals"); + println!("Public vals size: {} bytes", public_vals_u8_size); + + let mut transcript = Blake2bWrite::<_, G1Affine, Challenge255<_>>::init(vec![]); + create_proof::< + KZGCommitmentScheme, + ProverSHPLONK<'_, Bn256>, + Challenge255, + _, + Blake2bWrite, G1Affine, Challenge255>, + ModelCircuit, + >( + ¶ms, + &pk, + &[proof_circuit], + &[&[&public_vals]], + rng, + &mut transcript, + ) + .unwrap(); + let proof = transcript.finalize(); + let proof_duration = start.elapsed(); + println!("Proving time: {:?}", proof_duration - fill_duration); + + let proof_size = serialize(&proof, "proof"); + let proof = std::fs::read("proof").unwrap(); + + println!("Proof size: {} bytes", proof_size); + + let strategy = SingleStrategy::new(¶ms); + let transcript_read = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]); + + println!("public vals: {:?}", public_vals); + verify_kzg( + ¶ms, + &pk.get_vk(), + strategy, + &public_vals, + transcript_read, + ); + let verify_duration = start.elapsed(); + println!("Verifying time: {:?}", verify_duration - proof_duration); +} + +// Standalone verification +pub fn verify_circuit_kzg( + circuit: ModelCircuit, + vkey_fname: &str, + proof_fname: &str, + public_vals_fname: &str, +) { + let degree = circuit.k as u32; + let params = get_kzg_params("./params_kzg", degree); + println!("Loaded the parameters"); + + let vk = VerifyingKey::read::, ModelCircuit>( + &mut BufReader::new(File::open(vkey_fname).unwrap()), + SerdeFormat::RawBytes, + (), + ) + .unwrap(); + println!("Loaded vkey"); + + let proof = std::fs::read(proof_fname).unwrap(); + + let public_vals_u8 = std::fs::read(&public_vals_fname).unwrap(); + 
let public_vals: Vec = public_vals_u8 + .chunks(32) + .map(|chunk| Fr::from_bytes(chunk.try_into().expect("conversion failed")).unwrap()) + .collect(); + + let strategy = SingleStrategy::new(¶ms); + let transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]); + + let start = Instant::now(); + let verify_start = start.elapsed(); + verify_kzg(¶ms, &vk, strategy, &public_vals, transcript); + let verify_duration = start.elapsed(); + println!("Verifying time: {:?}", verify_duration - verify_start); + println!("Proof verified!") +}