Skip to content

Commit

Permalink
Merge pull request #1409 from informalsystems/gabriela/type-quantific…
Browse files Browse the repository at this point in the history
…ation-fix

Fix the type quantification strategy
  • Loading branch information
bugarela authored Mar 22, 2024
2 parents cb460f3 + 401ca15 commit 60ed3ba
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 106 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Removed a dependency causing deprecation errors messages to be emitted.
(#1380)
- Fixed a type checker bug causing too general types to be inferred (#1409).

### Security

Expand Down
5 changes: 1 addition & 4 deletions examples/cosmwasm/zero-to-hero/vote.qnt
Original file line number Diff line number Diff line change
Expand Up @@ -489,10 +489,7 @@ module state {
)
// assert that aggregated sum in `polls[poll_id]` equals the sum from above
val poll = state.polls.get(poll_id)
poll.options.listForall(option =>
val optionKey = option._1
// FIXME(#1167): Type annotation below is a workaround, inferred type is too general
val optionSum: int = option._2
poll.options.listForall(((optionKey, optionSum)) =>
// `ballots` only contains entries if there are > 0 votes.
optionSum > 0 implies and {
summed_ballots.keys().contains(optionKey),
Expand Down
5 changes: 3 additions & 2 deletions quint/src/types/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@ export type Constraint =
*/
const constraintKinds = ['empty', 'eq', 'conjunction', 'isDefined'] as const

export interface TypeScheme {
type: QuintType
export interface QuantifiedVariables {
typeVariables: Set<string>
rowVariables: Set<string>
}

export type TypeScheme = { type: QuintType } & QuantifiedVariables

export type Signature = (_arity: number) => TypeScheme

// Does not bind any type variables in `type`, which we take to assume
Expand Down
192 changes: 108 additions & 84 deletions quint/src/types/constraintGenerator.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/* ----------------------------------------------------------------------------------
* Copyright 2022 Informal Systems
* Copyright 2022-2024 Informal Systems
* Licensed under the Apache License, Version 2.0.
* See LICENSE in the project root for license information.
* --------------------------------------------------------------------------------- */
Expand All @@ -18,7 +18,7 @@ import {
QuintAssume,
QuintBool,
QuintConst,
QuintDef,
QuintDeclaration,
QuintEx,
QuintInstance,
QuintInt,
Expand All @@ -35,7 +35,7 @@ import { expressionToString, rowToString, typeToString } from '../ir/IRprinting'
import { Either, left, mergeInMany, right } from '@sweet-monads/either'
import { Error, ErrorTree, buildErrorLeaf, buildErrorTree, errorTreeToString } from '../errorTree'
import { getSignatures } from './builtinSignatures'
import { Constraint, Signature, TypeScheme, toScheme } from './base'
import { Constraint, QuantifiedVariables, Signature, TypeScheme, toScheme } from './base'
import { Substitutions, applySubstitution, compose } from './substitutions'
import { LookupTable } from '../names/base'
import {
Expand Down Expand Up @@ -95,7 +95,7 @@ export class ConstraintGeneratorVisitor implements IRVisitor {
}
}

protected types: Map<bigint, TypeScheme> = new Map<bigint, TypeScheme>()
protected types: Map<bigint, TypeScheme> = new Map()
protected errors: Map<bigint, ErrorTree> = new Map<bigint, ErrorTree>()
protected freshVarGenerator: FreshVarGenerator
protected table: LookupTable
Expand All @@ -104,13 +104,14 @@ export class ConstraintGeneratorVisitor implements IRVisitor {
private solvingFunction: SolvingFunctionType
private builtinSignatures: Map<string, Signature> = getSignatures()

// A map to save which type variables were free when we started visiting an opdef or an assume
protected tvs: Map<bigint, QuantifiedVariables> = new Map()
// Temporary type map only for types in scope for a certain declaration
protected typesInScope: Map<bigint, TypeScheme> = new Map()

// Track location descriptions for error tree traces
private location: string = ''

// A stack of free type variables and row variables for lambda expressions.
// Nested lambdas add new entries to the stack, and pop them when exiting.
private freeNames: { typeVariables: Set<string>; rowVariables: Set<string> }[] = []

getResult(): [Map<bigint, ErrorTree>, Map<bigint, TypeScheme>] {
return [this.errors, this.types]
}
Expand All @@ -119,18 +120,6 @@ export class ConstraintGeneratorVisitor implements IRVisitor {
this.location = `Generating constraints for ${expressionToString(e)}`
}

exitDef(def: QuintDef) {
if (this.constraints.length > 0) {
this.solveConstraints().map(subs => {
if (isAnnotatedDef(def)) {
checkAnnotationGenerality(subs, def.typeAnnotation).mapLeft(err =>
this.errors.set(def.typeAnnotation?.id ?? def.id, err)
)
}
})
}
}

exitVar(e: QuintVar) {
this.addToResults(e.id, right(toScheme(e.typeAnnotation)))
}
Expand Down Expand Up @@ -242,22 +231,14 @@ export class ConstraintGeneratorVisitor implements IRVisitor {
}

enterLambda(expr: QuintLambda) {
const lastParamNames = this.currentFreeNames()
const paramNames = {
typeVariables: new Set(lastParamNames.typeVariables),
rowVariables: new Set(lastParamNames.rowVariables),
}
expr.params.forEach(p => {
const varName = p.name === '_' ? this.freshVarGenerator.freshVar('_t') : `t_${p.name}_${p.id}`
paramNames.typeVariables.add(varName)
const paramTypeVar: QuintVarType = { kind: 'var', name: varName }
this.addToResults(p.id, right(toScheme(paramTypeVar)))
if (p.typeAnnotation) {
this.constraints.push({ kind: 'eq', types: [paramTypeVar, p.typeAnnotation], sourceId: p.id })
}
})

this.freeNames.push(paramNames)
}

// Γ ∪ {p0: t0, ..., pn: tn} ⊢ e: (te, c) t0, ..., tn are fresh
Expand All @@ -281,7 +262,6 @@ export class ConstraintGeneratorVisitor implements IRVisitor {
})

this.addToResults(e.id, result)
this.freeNames.pop()
}

// Γ ⊢ e1: (t1, c1) s = solve(c1) s(Γ ∪ {n: t1}) ⊢ e2: (t2, c2)
Expand All @@ -292,22 +272,58 @@ export class ConstraintGeneratorVisitor implements IRVisitor {
return
}

// TODO: Occurs check on operator body to prevent recursion, see https://github.com/informalsystems/quint/issues/171

this.addToResults(e.id, this.fetchResult(e.expr.id))
}

exitOpDef(e: QuintOpDef) {
exitDecl(_def: QuintDeclaration) {
this.typesInScope = new Map()
}

enterOpDef(def: QuintOpDef) {
// Save which type variables were free when we started visiting this op def
const tvs = this.freeNamesInScope()
this.tvs.set(def.id, tvs)
}

exitOpDef(def: QuintOpDef) {
if (this.errors.size !== 0) {
return
}

this.fetchResult(e.expr.id).map(t => {
this.addToResults(e.id, right(this.quantify(t.type)))
if (e.typeAnnotation) {
this.constraints.push({ kind: 'eq', types: [t.type, e.typeAnnotation], sourceId: e.id })
this.fetchResult(def.expr.id).map(t => {
if (def.typeAnnotation) {
this.constraints.push({ kind: 'eq', types: [t.type, def.typeAnnotation], sourceId: def.id })
}
})

const tvs_before = this.tvs.get(def.id)!

if (this.constraints.length > 0) {
this.solveConstraints().map(subs => {
// For every free name we are binding in the substitutions, the names occuring in the value of the substitution
// have to become free as well.
addBindingsToFreeTypes(tvs_before, subs)

if (isAnnotatedDef(def)) {
checkAnnotationGenerality(subs, def.typeAnnotation).mapLeft(err =>
this.errors.set(def.typeAnnotation?.id ?? def.id, err)
)
}
})
}

const tvs = this.freeNamesInScope()
// Any new free names, that were not free before, have to be quantified
const toQuantify = variablesDifference(tvs, tvs_before)

this.fetchResult(def.expr.id).map(t => {
this.addToResults(def.id, right(quantify(toQuantify, t.type)))
})
}

enterAssume(e: QuintAssume) {
const tvs = this.freeNamesInScope()
this.tvs.set(e.id, tvs)
}

exitAssume(e: QuintAssume) {
Expand All @@ -316,15 +332,21 @@ export class ConstraintGeneratorVisitor implements IRVisitor {
}

this.fetchResult(e.assumption.id).map(t => {
this.addToResults(e.id, right(this.quantify(t.type)))
const tvs_before = this.tvs.get(e.id)!
const tvs = this.freeNamesInScope()
const toQuantify = variablesDifference(tvs, tvs_before)
this.addToResults(e.id, right(quantify(toQuantify, t.type)))
this.constraints.push({ kind: 'eq', types: [t.type, { kind: 'bool' }], sourceId: e.id })
})
}

private addToResults(exprId: bigint, result: Either<Error, TypeScheme>) {
result
.mapLeft(err => this.errors.set(exprId, buildErrorTree(this.location, err)))
.map(r => this.types.set(exprId, r))
.map(r => {
this.typesInScope.set(exprId, r)
this.types.set(exprId, r)
})
}

private fetchResult(id: bigint): Either<ErrorTree, TypeScheme> {
Expand All @@ -348,16 +370,9 @@ export class ConstraintGeneratorVisitor implements IRVisitor {
return this.solvingFunction(this.table, constraint)
.mapLeft(errors => errors.forEach((err, id) => this.errors.set(id, err)))
.map(subs => {
// For every free name we are binding in the substitutions, the names occuring in the value of the substitution
// have to become free as well.
this.addBindingsToFreeNames(subs)

// Apply substitution to environment
// FIXME: We have to figure out the scope of the constraints/substitutions
// https://github.com/informalsystems/quint/issues/690
this.types.forEach((oldScheme, id) => {
this.typesInScope.forEach((oldScheme, id) => {
const newType = applySubstitution(this.table, subs, oldScheme.type)
const newScheme: TypeScheme = this.quantify(newType)
const newScheme: TypeScheme = { ...oldScheme, type: newType }
this.addToResults(id, right(newScheme))
})

Expand Down Expand Up @@ -406,45 +421,18 @@ export class ConstraintGeneratorVisitor implements IRVisitor {
return applySubstitution(this.table, subs, t.type)
}

private currentFreeNames(): { typeVariables: Set<string>; rowVariables: Set<string> } {
return (
this.freeNames[this.freeNames.length - 1] ?? {
typeVariables: new Set(),
rowVariables: new Set(),
}
private freeNamesInScope(): QuantifiedVariables {
return [...this.typesInScope.values()].reduce(
(acc, t) => {
const names = freeTypes(t)
return {
typeVariables: new Set([...names.typeVariables, ...acc.typeVariables]),
rowVariables: new Set([...names.rowVariables, ...acc.rowVariables]),
}
},
{ typeVariables: new Set(), rowVariables: new Set() }
)
}

private quantify(type: QuintType): TypeScheme {
const freeNames = this.currentFreeNames()
const nonFreeNames = {
typeVariables: new Set([...typeNames(type).typeVariables].filter(name => !freeNames.typeVariables.has(name))),
rowVariables: new Set([...typeNames(type).rowVariables].filter(name => !freeNames.rowVariables.has(name))),
}
return { ...nonFreeNames, type }
}

private addBindingsToFreeNames(substitutions: Substitutions) {
// Assumes substitutions are topologically sorted, i.e. [ t0 |-> (t1, t2), t1 |-> (t3, t4) ]
substitutions.forEach(s => {
switch (s.kind) {
case 'type':
this.freeNames
.filter(free => free.typeVariables.has(s.name))
.forEach(free => {
const names = typeNames(s.value)
names.typeVariables.forEach(v => free.typeVariables.add(v))
names.rowVariables.forEach(v => free.rowVariables.add(v))
})
return
case 'row':
this.freeNames
.filter(free => free.rowVariables.has(s.name))
.forEach(free => rowNames(s.value).forEach(v => free.rowVariables.add(v)))
return
}
})
}
}

function checkAnnotationGenerality(
Expand Down Expand Up @@ -477,3 +465,39 @@ function checkAnnotationGenerality(
return right(subs)
}
}

function quantify(tvs: QuantifiedVariables, type: QuintType): TypeScheme {
return { ...tvs, type }
}

function freeTypes(t: TypeScheme): QuantifiedVariables {
const allNames = typeNames(t.type)
return variablesDifference(allNames, { typeVariables: t.typeVariables, rowVariables: t.rowVariables })
}

function addBindingsToFreeTypes(free: QuantifiedVariables, substitutions: Substitutions): void {
// Assumes substitutions are topologically sorted, i.e. [ t0 |-> (t1, t2), t1 |-> (t3, t4) ]
substitutions.forEach(s => {
switch (s.kind) {
case 'type':
if (free.typeVariables.has(s.name)) {
const names = typeNames(s.value)
names.typeVariables.forEach(v => free.typeVariables.add(v))
names.rowVariables.forEach(v => free.rowVariables.add(v))
}
return
case 'row':
if (free.rowVariables.has(s.name)) {
rowNames(s.value).forEach(v => free.rowVariables.add(v))
}
return
}
})
}

function variablesDifference(a: QuantifiedVariables, b: QuantifiedVariables): QuantifiedVariables {
return {
typeVariables: new Set([...a.typeVariables].filter(tv => !b.typeVariables.has(tv))),
rowVariables: new Set([...a.rowVariables].filter(tv => !b.rowVariables.has(tv))),
}
}
Loading

0 comments on commit 60ed3ba

Please sign in to comment.