Skip to content

Commit

Permalink
add scripts for qmb bounds derivation
Browse files Browse the repository at this point in the history
  • Loading branch information
freibold authored and dopitz committed Dec 16, 2024
1 parent c7d6668 commit da78d2c
Show file tree
Hide file tree
Showing 5 changed files with 5,729 additions and 0 deletions.
346 changes: 346 additions & 0 deletions scripts/qmbb/CodeGenTest.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,346 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import sympy as sp\n",
"import numpy as np\n",
"sp.init_printing()\n",
"\n",
"%load_ext cython"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"x0, x1, x2 = sp.symbols(\"x0, x1, x2\")\n",
"params = sp.MatrixSymbol('params', 3, 1)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"A = sp.Matrix([x0,\n",
" x1, \n",
" x2, \n",
" x2**2+x0+x1, \n",
" x1**2+x1+1, \n",
" x1**2+x0+1, \n",
" x2+1, \n",
" x0**2+x1+x2, \n",
" x2**2+x2+1, \n",
" x2**2+x0**2+x0+2, \n",
" x1+1, \n",
" x0**2+x0+x1])"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"param_map = dict(zip([x0, x1, x2], params))\n",
"B = A.xreplace(param_map)\n",
"R = sp.MatrixSymbol(\"coeff\", B.shape[0], B.shape[1])"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from sympy.printing.ccode import C99CodePrinter\n",
"printer = C99CodePrinter()\n",
"\n",
"class CustomCodePrinter(C99CodePrinter):\n",
" def _print_Pow(self, expr):\n",
" if expr.exp.is_integer and expr.exp > 0 and expr.exp < 5:\n",
" return '*'.join([self._print(expr.base) for i in range(expr.exp)])\n",
" else:\n",
" return super()._print_Pow(expr)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from sympy.utilities.codegen import codegen, default_datatypes"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle coeff = \\left[\\begin{matrix}params_{0, 0}\\\\params_{1, 0}\\\\params_{2, 0}\\\\params_{0, 0} + params_{1, 0} + params_{2, 0}^{2}\\\\params_{1, 0}^{2} + params_{1, 0} + 1\\\\params_{0, 0} + params_{1, 0}^{2} + 1\\\\params_{2, 0} + 1\\\\params_{0, 0}^{2} + params_{1, 0} + params_{2, 0}\\\\params_{2, 0}^{2} + params_{2, 0} + 1\\\\params_{0, 0}^{2} + params_{0, 0} + params_{2, 0}^{2} + 2\\\\params_{1, 0} + 1\\\\params_{0, 0}^{2} + params_{0, 0} + params_{1, 0}\\end{matrix}\\right]$"
],
"text/plain": [
" ⎡ params₀₀ ⎤\n",
" ⎢ ⎥\n",
" ⎢ params₁₀ ⎥\n",
" ⎢ ⎥\n",
" ⎢ params₂₀ ⎥\n",
" ⎢ ⎥\n",
" ⎢ 2 ⎥\n",
" ⎢ params₀₀ + params₁₀ + params₂₀ ⎥\n",
" ⎢ ⎥\n",
" ⎢ 2 ⎥\n",
" ⎢ params₁₀ + params₁₀ + 1 ⎥\n",
" ⎢ ⎥\n",
" ⎢ 2 ⎥\n",
" ⎢ params₀₀ + params₁₀ + 1 ⎥\n",
" ⎢ ⎥\n",
"coeff = ⎢ params₂₀ + 1 ⎥\n",
" ⎢ ⎥\n",
" ⎢ 2 ⎥\n",
" ⎢ params₀₀ + params₁₀ + params₂₀ ⎥\n",
" ⎢ ⎥\n",
" ⎢ 2 ⎥\n",
" ⎢ params₂₀ + params₂₀ + 1 ⎥\n",
" ⎢ ⎥\n",
" ⎢ 2 2 ⎥\n",
" ⎢params₀₀ + params₀₀ + params₂₀ + 2⎥\n",
" ⎢ ⎥\n",
" ⎢ params₁₀ + 1 ⎥\n",
" ⎢ ⎥\n",
" ⎢ 2 ⎥\n",
" ⎣ params₀₀ + params₀₀ + params₁₀ ⎦"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sp.Eq(R,B)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/******************************************************************************\n",
" * Code generated with sympy 1.4 *\n",
" * *\n",
" * See http://www.sympy.org/ for more information. *\n",
" * *\n",
" * This file is part of 'project' *\n",
" ******************************************************************************/\n",
"#include \"coefficients.h\"\n",
"#include <math.h>\n",
"\n",
"void coefficients(float *params, float *coeff) {\n",
"\n",
" coeff[0] = params[0];\n",
" coeff[1] = params[1];\n",
" coeff[2] = params[2];\n",
" coeff[3] = params[0] + params[1] + params[2]*params[2];\n",
" coeff[4] = params[1]*params[1] + params[1] + 1;\n",
" coeff[5] = params[0] + params[1]*params[1] + 1;\n",
" coeff[6] = params[2] + 1;\n",
" coeff[7] = params[0]*params[0] + params[1] + params[2];\n",
" coeff[8] = params[2]*params[2] + params[2] + 1;\n",
" coeff[9] = params[0]*params[0] + params[0] + params[2]*params[2] + 2;\n",
" coeff[10] = params[1] + 1;\n",
" coeff[11] = params[0]*params[0] + params[0] + params[1];\n",
"\n",
"}\n",
"\n"
]
}
],
"source": [
"from sympy.codegen.ast import real, float32\n",
"customprinter = CustomCodePrinter()\n",
"#customprinter.type_aliases[real] = float32 # cosf instead of cos\n",
"default_datatypes[\"float\"].cname = \"float\" # float instead of double\n",
"\n",
"[(cf, cs), (hf, hs)] = codegen(('coefficients', sp.Eq(R,B)), language='c', printer=customprinter)\n",
"print(cs)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"codegen(('coefficients', sp.Eq(R,B)), language='c', printer=customprinter, prefix='coefficients', to_files=True)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Overwriting cy_coefficients.pyxbld\n"
]
}
],
"source": [
"%%writefile cy_coefficients.pyxbld\n",
"import numpy\n",
"\n",
"# module name specified by `%%cython_pyximport` magic\n",
"# | just `modname + \".pyx\"`\n",
"# | |\n",
"def make_ext(modname, pyxfilename):\n",
" from setuptools.extension import Extension\n",
" return Extension(modname,\n",
" sources=[pyxfilename, 'coefficients.c'],\n",
" include_dirs=['.', numpy.get_include()])"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/nfs/ka/home/freibold/.local/lib/python3.7/site-packages/Cython/Compiler/Main.py:369: FutureWarning: Cython directive 'language_level' not set, using 2 for now (Py2). This will change in a later release! File: /nfs/ka/home/freibold/OneDrive/MotionCurve/cy_coefficients.pyx\n",
" tree = Parsing.p_module(s, pxd, full_module_name)\n"
]
}
],
"source": [
"%%cython_pyximport cy_coefficients\n",
"import numpy as np\n",
"cimport numpy as cnp # cimport gives us access to NumPy's C API\n",
"\n",
"# here we just replicate the function signature from the header\n",
"cdef extern from \"coefficients.h\":\n",
" void coefficients(float *params, float *result)\n",
"\n",
"# here is the \"wrapper\" signature that conforms to the odeint interface\n",
"def cy_coefficients(cnp.ndarray[cnp.float32_t, ndim=1] params, size):\n",
" # preallocate our output array\n",
" cdef cnp.ndarray[cnp.float32_t, ndim=1] result = np.empty(size, dtype='float32')\n",
" # now call the C function\n",
" coefficients(<float *> params.data, <float *> result.data)\n",
" # return the result\n",
" return result"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 1. 2. 3. 12. 7. 6. 4. 6. 13. 13. 3. 4.]\n"
]
}
],
"source": [
"params = np.array([1, 2, 3], dtype='float32')\n",
"result = cy_coefficients(params, 12)\n",
"print(result)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle \\left[\\begin{matrix}1\\\\2\\\\3\\\\12\\\\7\\\\6\\\\4\\\\6\\\\13\\\\13\\\\3\\\\4\\end{matrix}\\right]$"
],
"text/plain": [
"⎡1 ⎤\n",
"⎢ ⎥\n",
"⎢2 ⎥\n",
"⎢ ⎥\n",
"⎢3 ⎥\n",
"⎢ ⎥\n",
"⎢12⎥\n",
"⎢ ⎥\n",
"⎢7 ⎥\n",
"⎢ ⎥\n",
"⎢6 ⎥\n",
"⎢ ⎥\n",
"⎢4 ⎥\n",
"⎢ ⎥\n",
"⎢6 ⎥\n",
"⎢ ⎥\n",
"⎢13⎥\n",
"⎢ ⎥\n",
"⎢13⎥\n",
"⎢ ⎥\n",
"⎢3 ⎥\n",
"⎢ ⎥\n",
"⎣4 ⎦"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"A.subs({(x0, 1), (x1, 2), (x2, 3)})"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading

0 comments on commit da78d2c

Please sign in to comment.