From ebff0d818b4b369835744f696ada4eda135bfc47 Mon Sep 17 00:00:00 2001 From: Jonathan Protzenko Date: Fri, 27 Dec 2024 10:35:06 -0800 Subject: [PATCH] Renormalize data types properly --- lib/Checker.ml | 12 ++++++++---- lib/Monomorphization.ml | 18 ++++++++++++++++++ lib/MonomorphizationState.ml | 30 +++++++++++++++++++++++++++++- 3 files changed, 55 insertions(+), 5 deletions(-) diff --git a/lib/Checker.ml b/lib/Checker.ml index 54b2597d..81cc31a7 100644 --- a/lib/Checker.ml +++ b/lib/Checker.ml @@ -1089,16 +1089,18 @@ and assert_cons_of env t id: fields_t = checker_error env "the annotated type %a is not a variant type" ptyp (TAnonymous t) and subtype env t1 t2 = - if Options.debug "checker" then - KPrint.bprintf "%a <=? %a\n" ptyp t1 ptyp t2; let normalize t = - match MonomorphizationState.resolve (expand_abbrev env t) with + match MonomorphizationState.resolve_deep (expand_abbrev env t) with | TBuf (TApp ((["Eurydice"], "derefed_slice"), [ t ]), _) -> TApp ((["Eurydice"], "slice"), [t]) | t -> t in - match normalize t1, normalize t2 with + let t1 = normalize t1 in + let t2 = normalize t2 in + if Options.debug "checker" then + KPrint.bprintf "%a <=? %a\n" ptyp t1 ptyp t2; + match t1, t2 with | TInt w1, TInt w2 when w1 = w2 -> true | TInt K.SizeT, TInt K.UInt32 when Options.wasm () -> @@ -1184,6 +1186,8 @@ and subtype env t1 t2 = subtype env t2 t1 | _ -> + if Options.debug "checker" then + MonomorphizationState.debug (); false and eqtype env t1 t2 = diff --git a/lib/Monomorphization.ml b/lib/Monomorphization.ml index 82eed717..80abf7f6 100644 --- a/lib/Monomorphization.ml +++ b/lib/Monomorphization.ml @@ -298,6 +298,17 @@ let monomorphize_data_types map = object(self) TQualified chosen_lid end)#visit_typ () t + (* We need to renormalize entries in the map for the Checker module. For + instance, the map might contain `t (u v) -> t0` and `u v -> u0`, but at + this stage, we will have a type error when trying to compare `t (u v)` and + `t u0`, since the latter does not appear in the map. *) + method private renormalize_entry (n, ts, cgs) chosen_lid = + (* We do this on the fly to make sure that types that appear in ts have + themselves been renormalized. *) + let ts' = List.map resolve_deep ts in + if not (Hashtbl.mem state (n, ts', cgs)) then + Hashtbl.add state (n, ts', cgs) (Black, chosen_lid) + (* Compute the name of a given node in the graph. *) method private lid_of (n: node) = let lid, ts, cgs = n in @@ -340,6 +351,7 @@ let monomorphize_data_types map = object(self) (* For tuples, we immediately know how to generate a definition. *) let fields = List.mapi (fun i arg -> Some (self#field_at i), (arg, false)) args in self#record (DType (chosen_lid, [ Common.Private ] @ flag, 0, 0, Flat fields)); + self#renormalize_entry n chosen_lid; Hashtbl.replace state n (Black, chosen_lid) end else begin (* This specific node has not been visited yet. *) @@ -352,6 +364,7 @@ let monomorphize_data_types map = object(self) begin match Hashtbl.find map lid with | exception Not_found -> (* Unknown, external non-polymorphic lid, e.g. Prims.int *) + self#renormalize_entry n chosen_lid; Hashtbl.replace state n (Black, chosen_lid) | flags, ((Variant _ | Flat _ | Union _) as def) when under_ref && not (Hashtbl.mem seen_declarations lid) -> (* Because this looks up a definition in the global map, the @@ -382,10 +395,12 @@ let monomorphize_data_types map = object(self) let branches = List.map (fun (cons, fields) -> cons, subst fields) branches in let branches = self#visit_branches_t under_ref branches in self#record (DType (chosen_lid, flag @ flags, 0, 0, Variant branches)); + self#renormalize_entry n chosen_lid; Hashtbl.replace state n (Black, chosen_lid) | flags, Flat fields -> let fields = self#visit_fields_t_opt under_ref (subst fields) in self#record (DType (chosen_lid, flag @ flags, 0, 0, Flat fields)); + self#renormalize_entry n chosen_lid; Hashtbl.replace state n (Black, chosen_lid) | flags, Union fields -> let fields = List.map (fun (f, t) -> @@ -394,13 +409,16 @@ let monomorphize_data_types map = object(self) f, t ) fields in self#record (DType (chosen_lid, flag @ flags, 0, 0, Union fields)); + self#renormalize_entry n chosen_lid; Hashtbl.replace state n (Black, chosen_lid) | flags, Abbrev t -> let t = DeBruijn.subst_tn args t in let t = self#visit_typ under_ref t in self#record (DType (chosen_lid, flag @ flags, 0, 0, Abbrev t)); + self#renormalize_entry n chosen_lid; Hashtbl.replace state n (Black, chosen_lid) | _ -> + self#renormalize_entry n chosen_lid; Hashtbl.replace state n (Black, chosen_lid) end end; diff --git a/lib/MonomorphizationState.ml b/lib/MonomorphizationState.ml index c75e378b..4adfa45a 100644 --- a/lib/MonomorphizationState.ml +++ b/lib/MonomorphizationState.ml @@ -1,8 +1,23 @@ open Ast +open PrintAst.Ops + +(* Various bits of state for monomorphization, the two most important being + `state` (type monomorphization) and `generated_lids` (function + monomorphization). *) + +(* Monomorphization of data types. *) type node = lident * typ list * cg list type color = Gray | Black + +(* Each polymorphic type `lid` applied to types `ts` and const generics `ts` + appears in `state`, and maps to `monomorphized_lid`, the name of its + monomorphized instance. *) let state: (node, color * lident) Hashtbl.t = Hashtbl.create 41 +(* Because of polymorphic externals, one still encounters, + post-monomorphizations, application nodes in types (e.g. after instantiating + a polymorphic type scheme). The `resolve*` functions, below, normalize a type + to only contain monomorphic type names (and no more type applications) *) let resolve t: typ = match t with | TApp _ | TCgApp _ when Hashtbl.mem state (flatten_tapp t) -> @@ -27,4 +42,17 @@ let resolve_deep = (object(self) resolve (TTuple ts) end)#visit_typ () -let generated_lids: (lident * expr list * typ list, lident) Hashtbl.t = Hashtbl.create 41 +(* Monomorphization of functions *) +type reverse_mapping = (lident * expr list * typ list, lident) Hashtbl.t + +let generated_lids: reverse_mapping = Hashtbl.create 41 + +let debug () = + Hashtbl.iter (fun (lid, ts, cgs) (_, monomorphized_lid) -> + KPrint.bprintf "%a <%a> <%a> ~~> %a\n" plid lid pcgs cgs ptyps ts plid + monomorphized_lid + ) state; + Hashtbl.iter (fun (lid, es, ts) monomorphized_lid -> + KPrint.bprintf "%a <%a> <%a> ~~> %a\n" plid lid pexprs es ptyps ts plid + monomorphized_lid + ) generated_lids