Skip to content

Commit

Permalink
binding/python: generate io bindings
Browse files Browse the repository at this point in the history
Generate io bindings including prototypes in mpi_proto.h and
mpir_impl.h.

src/binding/c/io.c caintains all the IO binding functions that calls the
corresponding _impl functions. But it is not included in the Makefile
yet.

Dump the prototypes for io impl functions separately to allow inserting
necessary missing declarations such as MPIR_Ext_cs_enter/exit.

IO functions need use MPIR_Ext_cs_enter/exit for global cs. Now it gets
more complicated, refactor into dump_global_cs_enter/exit for better
readability.
  • Loading branch information
hzhou committed Dec 6, 2024
1 parent 05dc7e1 commit 46ebe51
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 21 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,7 @@ Makefile.am-stamp

# /src/binding/
/src/binding/c/c_binding.c
/src/binding/c/io.c

# /src/binding/cxx/
/src/binding/cxx/Makefile.sm
Expand Down
22 changes: 19 additions & 3 deletions maint/gen_binding_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@ def main():
abi_dir = "src/binding/abi"
func_list = load_C_func_list(binding_dir)

io_func_list = [f for f in func_list if f['dir'] == 'io']
func_list = [f for f in func_list if f['dir'] != 'io']

# -- Loading extra api prototypes (needed until other `buildiface` scripts are updated)
G.mpi_declares = []

Expand Down Expand Up @@ -53,6 +50,10 @@ def main():
repl_func['_replaces'].append(func)


# We generate io functions separately for now
io_func_list = [f for f in func_list if f['dir'] == 'io']
func_list = [f for f in func_list if f['dir'] != 'io']

# -- Generating code --
G.doc3_src_txt = []
G.poly_aliases = [] # large-count mansrc aliases
Expand Down Expand Up @@ -144,9 +145,23 @@ def dump_c_binding_abi():
G.check_write_path(abi_file_path)
dump_c_file(abi_file_path, G.out)

def dump_io_funcs():
G.out = []
G.out.append("#include \"mpiimpl.h\"")
G.out.append("#include \"mpir_io_impl.h\"")
G.out.append("")

for func in io_func_list:
G.err_codes = {}
manpage_out = []
dump_func(func, manpage_out)

dump_out(c_dir + "/io.c")

# ----
dump_c_binding()
dump_c_binding_abi()
dump_io_funcs()

if 'output-mansrc' in G.opts:
f = c_dir + '/mansrc/' + 'poly_aliases.lst'
Expand All @@ -160,6 +175,7 @@ def dump_c_binding_abi():
G.check_write_path("src/include/mpi_proto.h")
dump_Makefile_mk("%s/Makefile.mk" % c_dir)
dump_mpir_impl_h("src/include/mpir_impl.h")
dump_mpir_io_impl_h("src/include/mpir_io_impl.h")
dump_errnames_txt("%s/errnames.txt" % c_dir)
dump_qmpi_register_h("src/mpi_t/qmpi_register.h")
dump_mpi_proto_h("src/include/mpi_proto.h")
Expand Down
1 change: 1 addition & 0 deletions maint/local_python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def check_write_path(path):
mpi_sources = []
mpi_declares = []
impl_declares = []
io_impl_declares = []
mpi_errnames = []
mpix_symbols = {}

Expand Down
97 changes: 80 additions & 17 deletions maint/local_python/binding_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,35 @@ def dump_mpix_symbols():
print("", file=Out)
print("#endif /* MPIR_IMPL_H_INCLUDED */", file=Out)

def dump_mpir_io_impl_h(f):
print(" --> [%s]" %f)
with open(f, "w") as Out:
for l in G.copyright_c:
print(l, file=Out)
print("#ifndef MPIR_IO_IMPL_H_INCLUDED", file=Out)
print("#define MPIR_IO_IMPL_H_INCLUDED", file=Out)

print("", file=Out)
print("#define MPIR_ERR_RECOVERABLE 0", file=Out)
print("#define MPIR_ERR_FATAL 1", file=Out)
print("int MPIR_Err_create_code(int, int, const char[], int, int, const char[], const char[], ...);", file=Out)
print("void MPIR_Ext_cs_enter(void);", file=Out)
print("void MPIR_Ext_cs_exit(void);", file=Out)
print("#ifndef HAVE_ROMIO", file=Out)
print("int MPIR_Err_return_comm(void *, const char[], int);", file=Out)
print("#define MPIO_Err_return_file(fh, errorcode) MPIR_Err_return_comm((void *)0, __func__, errorcode)", file=Out)
print("#else", file=Out)
print("MPI_Fint MPIR_File_c2f_impl(MPI_File fh);", file=Out)
print("MPI_File MPIR_File_f2c_impl(MPI_Fint fh);", file=Out)
print("int MPIO_Err_return_file(MPI_File fh, int errorcode);", file=Out)
print("#endif", file=Out)

print("", file=Out)
for l in G.io_impl_declares:
print(l, file=Out)
print("", file=Out)
print("#endif /* MPIR_IO_IMPL_H_INCLUDED */", file=Out)

def filter_out_abi():
funcname = None
for l in G.out:
Expand Down Expand Up @@ -552,8 +581,10 @@ def process_func_parameters(func):

do_handle_ptr = 0
(kind, name) = (p['kind'], p['name'])
if '_has_comm' not in func and kind == "COMMUNICATOR" and p['param_direction'] == 'in':
if '_has_comm' not in func and kind == "COMMUNICATOR" and p['param_direction'] == 'in' and func['name'] != "MPI_File_open":
func['_has_comm'] = name
elif kind == "FILE" and p['param_direction'] == 'in' and func['dir'] == 'io':
func['_has_file'] = name
elif name == "win":
func['_has_win'] = name
elif name == "session":
Expand Down Expand Up @@ -733,15 +764,19 @@ def process_func_parameters(func):
elif RE.match(r'F90_(COMM|ERRHANDLER|FILE|GROUP|INFO|MESSAGE|OP|REQUEST|SESSION|DATATYPE|WIN)', kind):
# no validation for these kinds
pass
elif RE.match(r'(POLY)?(DTYPE_STRIDE_BYTES|DISPLACEMENT_AINT_COUNT)$', kind):
elif RE.match(r'(POLY)?(DTYPE_STRIDE_BYTES|DISPLACEMENT_AINT_COUNT|OFFSET)$', kind):
# e.g. stride in MPI_Type_vector, MPI_Type_create_resized
pass
elif is_pointer_type(p):
validation_list.append({'kind': "ARGNULL", 'name': name})
else:
print("Missing error checking: func=%s, name=%s, kind=%s" % (func_name, name, kind), file=sys.stderr)

if do_handle_ptr == 1:
if func['dir'] == 'io':
# pass io function parameters as is
impl_arg_list.append(name)
impl_param_list.append(get_impl_param(func, p))
elif do_handle_ptr == 1:
if p['param_direction'] == 'inout':
# assume only one such parameter
func['_has_handle_inout'] = p
Expand Down Expand Up @@ -1387,6 +1422,25 @@ def check_large_parameters(func):
func['_poly_in_list'].append(p)

def dump_function_normal(func):
def dump_global_cs_enter():
if not '_skip_global_cs' in func:
G.out.append("")
if func['dir'] == 'mpit':
G.out.append("MPIR_T_THREAD_CS_ENTER();")
elif func['dir'] == 'io':
G.out.append("MPIR_Ext_cs_enter();")
else:
G.out.append("MPID_THREAD_CS_ENTER(GLOBAL, MPIR_THREAD_GLOBAL_ALLFUNC_MUTEX);")
def dump_global_cs_exit():
if not '_skip_global_cs' in func:
G.out.append("")
if func['dir'] == 'mpit':
G.out.append("MPIR_T_THREAD_CS_EXIT();")
elif func['dir'] == 'io':
G.out.append("MPIR_Ext_cs_exit();")
else:
G.out.append("MPID_THREAD_CS_EXIT(GLOBAL, MPIR_THREAD_GLOBAL_ALLFUNC_MUTEX);")
# ----
G.out.append("int mpi_errno = MPI_SUCCESS;")
if '_handle_ptr_list' in func:
for p in func['_handle_ptr_list']:
Expand All @@ -1405,12 +1459,7 @@ def dump_function_normal(func):
else:
G.out.append("MPIR_ERRTEST_INITIALIZED_ORDIE();")

if not '_skip_global_cs' in func:
G.out.append("")
if func['dir'] == 'mpit':
G.out.append("MPIR_T_THREAD_CS_ENTER();")
else:
G.out.append("MPID_THREAD_CS_ENTER(GLOBAL, MPIR_THREAD_GLOBAL_ALLFUNC_MUTEX);")
dump_global_cs_enter()
G.out.append("MPIR_FUNC_TERSE_ENTER;")

if '_handle_ptr_list' in func:
Expand Down Expand Up @@ -1476,8 +1525,19 @@ def dump_function_normal(func):
G.out.append("goto fn_exit;")
dump_if_close()

need_endif = False
if func['dir'] == 'io' and not 'return' in func:
G.out.append("#ifndef HAVE_ROMIO")
G.out.append("mpi_errno = MPIR_Err_create_code(mpi_errno, MPIR_ERR_RECOVERABLE, __func__, __LINE__, MPI_ERR_OTHER, \"**notimpl\", 0);")
G.out.append("goto fn_fail;")
G.out.append("#else")
need_endif = True

dump_body_of_routine(func)

if need_endif:
G.out.append("#endif")

G.out.append("/* ... end of body of routine ... */")

# ----
Expand All @@ -1486,12 +1546,8 @@ def dump_function_normal(func):
for l in func['_clean_up']:
G.out.append(l)
G.out.append("MPIR_FUNC_TERSE_EXIT;")
dump_global_cs_exit()

if not '_skip_global_cs' in func:
if func['dir'] == 'mpit':
G.out.append("MPIR_T_THREAD_CS_EXIT();")
else:
G.out.append("MPID_THREAD_CS_EXIT(GLOBAL, MPIR_THREAD_GLOBAL_ALLFUNC_MUTEX);")
G.out.append("return mpi_errno;")
G.out.append("")
G.out.append("fn_fail:")
Expand Down Expand Up @@ -1635,7 +1691,10 @@ def push_impl_decl(func, impl_name=None):
mpir_name = re.sub(r'^MPIX?_', 'MPIR_', func['name'])
G.impl_declares.append("int %s(%s);" % (mpir_name, params))
# dump MPIR_Xxx_impl(...)
G.impl_declares.append("int %s(%s);" % (impl_name, params))
if func['dir'] == 'io':
G.io_impl_declares.append("int %s(%s);" % (impl_name, params))
else:
G.impl_declares.append("int %s(%s);" % (impl_name, params))

def push_threadcomm_impl_decl(func):
impl_name = re.sub(r'^mpix?_(comm_)?', 'MPIR_Threadcomm_', func['name'].lower())
Expand Down Expand Up @@ -2002,6 +2061,8 @@ def dump_mpi_fn_fail(func):
G.out.append("mpi_errno = MPIR_Err_return_comm(%s_ptr, __func__, mpi_errno);" % func['_has_comm'])
elif '_has_win' in func:
G.out.append("mpi_errno = MPIR_Err_return_win(win_ptr, __func__, mpi_errno);")
elif '_has_file' in func:
G.out.append("mpi_errno = MPIO_Err_return_file(%s, mpi_errno);" % func['_has_file'])
elif RE.match(r'mpi_session_init', func['name'], re.IGNORECASE):
G.out.append("mpi_errno = MPIR_Err_return_session_init(errhandler_ptr, __func__, mpi_errno);")
elif '_has_session' in func:
Expand All @@ -2010,6 +2071,8 @@ def dump_mpi_fn_fail(func):
G.out.append("mpi_errno = MPIR_Err_return_comm_create_from_group(errhandler_ptr, __func__, mpi_errno);")
elif '_has_group' in func:
G.out.append("mpi_errno = MPIR_Err_return_group(%s_ptr, __func__, mpi_errno);" % func['_has_group'])
elif re.match(r'MPI_File_(delete|open|close)', func['name']):
G.out.append("mpi_errno = MPIO_Err_return_file(MPI_FILE_NULL, mpi_errno);")
else:
G.out.append("mpi_errno = MPIR_Err_return_comm(0, __func__, mpi_errno);")

Expand All @@ -2034,9 +2097,9 @@ def get_fn_fail_create_code(func):
fmt = 'p'
elif kind in fmt_codes:
fmt = fmt_codes[kind]
elif mapping[kind] == "int":
elif mapping[kind] == "int" or mapping[kind] == "MPI_Fint":
fmt = 'd'
elif mapping[kind] == "MPI_Aint":
elif mapping[kind] == "MPI_Aint" or mapping[kind] == "MPI_Offset":
fmt = 'L'
elif mapping[kind] == "MPI_Count":
fmt = 'c'
Expand Down
3 changes: 2 additions & 1 deletion src/binding/abi/Makefile.mk
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ include_HEADERS += src/binding/abi/mpi_abi.h

mpi_abi_sources += \
src/binding/abi/mpi_abi_util.c \
src/binding/abi/c_binding_abi.c
src/binding/abi/c_binding_abi.c \
src/binding/abi/io_abi.c

endif BUILD_ABI_LIB

0 comments on commit 46ebe51

Please sign in to comment.