Skip to content

Commit

Permalink
Static Object Optimizations (#197)
Browse files Browse the repository at this point in the history
- This PR uses the Parser & StaticOptimizer thread local string interner
for keys in static objects
- Similarly, we deduplicate the String -> Boolean map used to determine
if a field is static.
- For static objects we also use immutable.VectorMap (with a
JavaWrapper) for the field set.
- Lastly for the value cache, we size it according to the number of keys
in the object this reduces unnecessary up-sizing for large objects, and
more importantly removes the large number of sparse maps we previously
had for small objects (the default was 16 elements)

Before: 855MB for the parsed file
![image
(9)](https://github.com/databricks/sjsonnet/assets/153361/285bec71-8704-4e0b-98cd-336332e1d22c)

After: 425MB
![image
(10)](https://github.com/databricks/sjsonnet/assets/153361/127f4c76-429c-4049-b881-214942368a71)

---------

Co-authored-by: Li Haoyi <32282535+lihaoyi-databricks@users.noreply.github.com>
  • Loading branch information
ahirreddy and lihaoyi-databricks authored Jan 9, 2024
1 parent 74bbe26 commit c462833
Show file tree
Hide file tree
Showing 10 changed files with 140 additions and 22 deletions.
54 changes: 54 additions & 0 deletions bench/src/main/scala/sjsonnet/MainBenchmark.scala
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,57 @@ class MainBenchmark {
))
}
}

// This is a dummy benchmark to see how much memory is used by the interpreter.
// You're meant to execute it, and it will generate stats about memory usage before exiting.
// dump. You can optionally pass an argument to instruct it to pause post run - then attach a
// profiler.
object MemoryBenchmark {

val dummyOut = MainBenchmark.createDummyOut

val cache = new DefaultParseCache

def main(args: Array[String]): Unit = {
assert(args.length <= 1, s"Too many arguments: ${args.mkString(",")}")
val pause: Boolean = if (args.length == 1) {
if (args(0) == "--pause") {
println("Run will pause after completion. Attach a profiler then.")
true
} else {
println("Unknown argument: " + args(0))
System.exit(1)
false
}
} else {
false
}
SjsonnetMain.main0(
MainBenchmark.mainArgs,
cache,
System.in,
dummyOut,
System.err,
os.pwd,
None
)
println("Pre-GC Stats")
println("============")
println("Total memory: " + Runtime.getRuntime.totalMemory())
println("Free memory: " + Runtime.getRuntime.freeMemory())
println("Used memory: " + (Runtime.getRuntime.totalMemory() - Runtime.getRuntime.freeMemory()))
System.gc()
// Wait for GC to finish
Thread.sleep(5000)
println("Post-GC Stats")
println("============")
println("Total memory: " + Runtime.getRuntime.totalMemory())
println("Free memory: " + Runtime.getRuntime.freeMemory())
println("Used memory: " + (Runtime.getRuntime.totalMemory() - Runtime.getRuntime.freeMemory()))

if (pause) {
println("Pausing. Attach a profiler")
Thread.sleep(1000000000)
}
}
}
8 changes: 5 additions & 3 deletions bench/src/main/scala/sjsonnet/OptimizerBenchmark.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package sjsonnet
import java.io.StringWriter
import java.util.concurrent.TimeUnit

import scala.collection.mutable

import fastparse.Parsed.Success
import org.openjdk.jmh.annotations._
import org.openjdk.jmh.infra._
Expand All @@ -26,13 +28,13 @@ class OptimizerBenchmark {
def setup(): Unit = {
val (allFiles, ev) = MainBenchmark.findFiles()
this.inputs = allFiles.map { case (p, s) =>
fastparse.parse(s, new Parser(p, true).document(_)) match {
fastparse.parse(s, new Parser(p, true, mutable.HashMap.empty, mutable.HashMap.empty).document(_)) match {
case Success(v, _) => v
}
}
this.ev = ev
val static = inputs.map {
case (expr, fs) => ((new StaticOptimizer(ev, new Std().Std)).optimize(expr), fs)
case (expr, fs) => ((new StaticOptimizer(ev, new Std().Std, mutable.HashMap.empty, mutable.HashMap.empty)).optimize(expr), fs)
}
val countBefore, countStatic = new Counter
inputs.foreach(t => assert(countBefore.transform(t._1) eq t._1))
Expand All @@ -45,7 +47,7 @@ class OptimizerBenchmark {
@Benchmark
def main(bh: Blackhole): Unit = {
bh.consume(inputs.foreach { case (expr, fs) =>
bh.consume((new StaticOptimizer(ev, new Std().Std)).optimize(expr))
bh.consume((new StaticOptimizer(ev, new Std().Std, mutable.HashMap.empty, mutable.HashMap.empty)).optimize(expr))
})
}

Expand Down
4 changes: 3 additions & 1 deletion bench/src/main/scala/sjsonnet/ParserBenchmark.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package sjsonnet

import java.util.concurrent.TimeUnit

import scala.collection.mutable.HashMap

import fastparse.Parsed.Success
import org.openjdk.jmh.annotations._
import org.openjdk.jmh.infra._
Expand All @@ -25,7 +27,7 @@ class ParserBenchmark {
@Benchmark
def main(bh: Blackhole): Unit = {
bh.consume(allFiles.foreach { case (p, s) =>
val res = fastparse.parse(s, new Parser(p, true).document(_))
val res = fastparse.parse(s, new Parser(p, true, HashMap.empty, HashMap.empty).document(_))
bh.consume(res.asInstanceOf[Success[_]])
})
}
Expand Down
13 changes: 13 additions & 0 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,19 @@ Profiler:
sbt bench/run
```

There's also a benchmark for memory usage:

Execute and print stats:
```
sbt 'set fork in run := true' 'set javaOptions in run ++= Seq("-Xmx6G", "-XX:+UseG1GC")' 'bench/runMain sjsonnet.MemoryBenchmark'
```

Execute and pause - this is useful if you want to attach a profiler after the run and deep dive the
object utilization.
```
sbt 'set fork in run := true' 'set javaOptions in run ++= Seq("-Xmx6G", "-XX:+UseG1GC")' 'bench/runMain sjsonnet.MemoryBenchmark --pause'
```

## Laziness

The Jsonnet language is _lazy_: expressions don't get evaluated unless
Expand Down
6 changes: 4 additions & 2 deletions sjsonnet/src/sjsonnet/Importer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -180,11 +180,13 @@ class CachedImporter(parent: Importer) extends Importer {
class CachedResolver(
parentImporter: Importer,
val parseCache: ParseCache,
strictImportSyntax: Boolean) extends CachedImporter(parentImporter) {
strictImportSyntax: Boolean,
internedStrings: mutable.HashMap[String, String],
internedStaticFieldSets: mutable.HashMap[Val.StaticObjectFieldSet, java.util.LinkedHashMap[String, java.lang.Boolean]]) extends CachedImporter(parentImporter) {

def parse(path: Path, content: ResolvedFile)(implicit ev: EvalErrorScope): Either[Error, (Expr, FileScope)] = {
parseCache.getOrElseUpdate((path, content.contentHash.toString), {
val parsed = fastparse.parse(content.getParserInput(), new Parser(path, strictImportSyntax).document(_)) match {
val parsed = fastparse.parse(content.getParserInput(), new Parser(path, strictImportSyntax, internedStrings, internedStaticFieldSets).document(_)) match {
case f @ Parsed.Failure(_, _, _) =>
val traced = f.trace()
val pos = new Position(new FileScope(path), traced.index)
Expand Down
10 changes: 8 additions & 2 deletions sjsonnet/src/sjsonnet/Interpreter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package sjsonnet

import java.io.{PrintWriter, StringWriter}

import scala.collection.mutable

import sjsonnet.Expr.Params

import scala.util.control.NonFatal
Expand All @@ -21,9 +23,13 @@ class Interpreter(extVars: Map[String, String],
std: Val.Obj = new Std().Std
) { self =>

val resolver = new CachedResolver(importer, parseCache, settings.strictImportSyntax) {
private val internedStrings = new mutable.HashMap[String, String]

private val internedStaticFieldSets = new mutable.HashMap[Val.StaticObjectFieldSet, java.util.LinkedHashMap[String, java.lang.Boolean]]

val resolver = new CachedResolver(importer, parseCache, settings.strictImportSyntax, internedStrings, internedStaticFieldSets) {
override def process(expr: Expr, fs: FileScope): Either[Error, (Expr, FileScope)] =
handleException(new StaticOptimizer(evaluator, std).optimize(expr), fs)
handleException(new StaticOptimizer(evaluator, std, internedStrings, internedStaticFieldSets).optimize(expr), fs)
}

private def warn(e: Error): Unit = warnLogger("[warning] " + formatError(e))
Expand Down
10 changes: 5 additions & 5 deletions sjsonnet/src/sjsonnet/Parser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,13 @@ object Parser {
}

class Parser(val currentFile: Path,
strictImportSyntax: Boolean) {
strictImportSyntax: Boolean,
internedStrings: mutable.HashMap[String, String],
internedStaticFieldSets: mutable.HashMap[Val.StaticObjectFieldSet, java.util.LinkedHashMap[String, java.lang.Boolean]]) {
import Parser._

private val fileScope = new FileScope(currentFile)

private val strings = new mutable.HashMap[String, String]

def Pos[_: P]: P[Position] = Index.map(offset => new Position(fileScope, offset))

def id[_: P] = P(
Expand Down Expand Up @@ -264,7 +264,7 @@ class Parser(val currentFile: Path,

def constructString(pos: Position, lines: Seq[String]) = {
val s = lines.mkString
val unique = strings.getOrElseUpdate(s, s)
val unique = internedStrings.getOrElseUpdate(s, s)
Val.Str(pos, unique)
}

Expand Down Expand Up @@ -335,7 +335,7 @@ class Parser(val currentFile: Path,
val a = exprs.iterator.filter(_.isInstanceOf[Expr.Member.AssertStmt]).asInstanceOf[Iterator[Expr.Member.AssertStmt]].toArray
if(a.isEmpty) null else a
}
if(binds == null && asserts == null && fields.forall(_.isStatic)) Val.staticObject(pos, fields)
if(binds == null && asserts == null && fields.forall(_.isStatic)) Val.staticObject(pos, fields, internedStaticFieldSets, internedStrings)
else Expr.ObjBody.MemberList(pos, binds, fields, asserts)
case (pos, exprs, Some(comps)) =>
val preLocals = exprs
Expand Down
10 changes: 8 additions & 2 deletions sjsonnet/src/sjsonnet/StaticOptimizer.scala
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
package sjsonnet

import scala.collection.mutable

import Expr._
import ScopedExprTransform._

/** StaticOptimizer performs necessary transformations for the evaluator (assigning ValScope
* indices) plus additional optimizations (post-order) and static checking (pre-order). */
class StaticOptimizer(ev: EvalScope, std: Val.Obj) extends ScopedExprTransform {
class StaticOptimizer(
ev: EvalScope,
std: Val.Obj,
internedStrings: mutable.HashMap[String, String],
internedStaticFieldSets: mutable.HashMap[Val.StaticObjectFieldSet, java.util.LinkedHashMap[String, java.lang.Boolean]]) extends ScopedExprTransform {
def optimize(e: Expr): Expr = transform(e)

def failOrWarn(msg: String, expr: Expr): Expr = {
Expand Down Expand Up @@ -61,7 +67,7 @@ class StaticOptimizer(ev: EvalScope, std: Val.Obj) extends ScopedExprTransform {
new Val.Arr(a.pos, a.value.map(e => e.asInstanceOf[Val]))

case m @ ObjBody.MemberList(pos, binds, fields, asserts) =>
if(binds == null && asserts == null && fields.forall(_.isStatic)) Val.staticObject(pos, fields)
if(binds == null && asserts == null && fields.forall(_.isStatic)) Val.staticObject(pos, fields, internedStaticFieldSets, internedStrings)
else m

case e => e
Expand Down
41 changes: 36 additions & 5 deletions sjsonnet/src/sjsonnet/Val.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package sjsonnet

import java.util
import java.util.Arrays

import sjsonnet.Expr.Member.Visibility
import sjsonnet.Expr.Params
Expand Down Expand Up @@ -297,15 +298,45 @@ object Val{
}
}

def staticObject(pos: Position, fields: Array[Expr.Member.Field]): Obj = {
final class StaticObjectFieldSet(protected val keys: Array[String]) {

override def hashCode(): Int = {
Arrays.hashCode(keys.asInstanceOf[Array[Object]])
}

override def equals(obj: scala.Any): Boolean = {
obj match {
case that: StaticObjectFieldSet =>
keys.sameElements(that.keys)
case _ => false
}
}
}

def staticObject(
pos: Position,
fields: Array[Expr.Member.Field],
internedKeyMaps: mutable.HashMap[StaticObjectFieldSet, java.util.LinkedHashMap[String, java.lang.Boolean]],
internedStrings: mutable.HashMap[String, String]): Obj = {
// Set the initial capacity to the number of fields divided by the default load factor + 1 -
// this ensures that we can fill up the map to the total number of fields without resizing.
// From JavaDoc - true for both Scala & Java HashMaps
val hashMapDefaultLoadFactor = 0.75f
val capacity = (fields.length / hashMapDefaultLoadFactor).toInt + 1
val cache = mutable.HashMap.empty[Any, Val]
val allKeys = new util.LinkedHashMap[String, java.lang.Boolean]
val allKeys = new util.LinkedHashMap[String, java.lang.Boolean](capacity, hashMapDefaultLoadFactor)
val keys = new Array[String](fields.length)
var idx = 0
fields.foreach {
case Expr.Member.Field(_, Expr.FieldName.Fixed(k), _, _, _, rhs: Val.Literal) =>
cache.put(k, rhs)
allKeys.put(k, false)
val uniqueKey = internedStrings.getOrElseUpdate(k, k)
cache.put(uniqueKey, rhs)
allKeys.put(uniqueKey, false)
keys(idx) = uniqueKey
idx += 1
}
new Val.Obj(pos, null, true, null, null, cache, allKeys)
val fieldSet = new StaticObjectFieldSet(keys)
new Val.Obj(pos, null, true, null, null, cache, internedKeyMaps.getOrElseUpdate(fieldSet, allKeys))
}

abstract class Func(val pos: Position,
Expand Down
6 changes: 4 additions & 2 deletions sjsonnet/test/src/sjsonnet/ParserTests.scala
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package sjsonnet

import scala.collection.mutable
import utest._
import Expr._
import fastparse.Parsed
import Val.{True, Num}
object ParserTests extends TestSuite{
def parse(s: String, strictImportSyntax: Boolean = false) = fastparse.parse(s, new Parser(null, strictImportSyntax).document(_)).get.value._1
def parseErr(s: String, strictImportSyntax: Boolean = false) = fastparse.parse(s, new Parser(null, strictImportSyntax).document(_), verboseFailures = true).asInstanceOf[Parsed.Failure].msg
def parse(s: String, strictImportSyntax: Boolean = false) = fastparse.parse(s, new Parser(null, strictImportSyntax, mutable.HashMap.empty, mutable.HashMap.empty).document(_)).get.value._1
def parseErr(s: String, strictImportSyntax: Boolean = false) = fastparse.parse(s, new Parser(null, strictImportSyntax, mutable.HashMap.empty, mutable.HashMap.empty).document(_), verboseFailures = true).asInstanceOf[Parsed.Failure].msg
val dummyFS = new FileScope(null)
def pos(i: Int) = new Position(dummyFS, i)
def tests = Tests{
Expand Down

0 comments on commit c462833

Please sign in to comment.