From c462833b6144ee673772e80f6446293ae0ce5e6d Mon Sep 17 00:00:00 2001 From: Ahir Reddy Date: Tue, 9 Jan 2024 05:47:31 -0800 Subject: [PATCH] Static Object Optimizations (#197) - This PR uses the Parser & StaticOptimizer thread-local string interner for keys in static objects. - Similarly, we deduplicate the String -> Boolean map used to determine if a field is static. - For static objects we also use immutable.VectorMap (with a JavaWrapper) for the field set. - Lastly, for the value cache, we size it according to the number of keys in the object; this reduces unnecessary up-sizing for large objects, and more importantly removes the large number of sparse maps we previously had for small objects (the default was 16 elements). Before: 855MB for the parsed file ![image (9)](https://github.com/databricks/sjsonnet/assets/153361/285bec71-8704-4e0b-98cd-336332e1d22c) After: 425MB ![image (10)](https://github.com/databricks/sjsonnet/assets/153361/127f4c76-429c-4049-b881-214942368a71) --------- Co-authored-by: Li Haoyi <32282535+lihaoyi-databricks@users.noreply.github.com> --- .../main/scala/sjsonnet/MainBenchmark.scala | 54 +++++++++++++++++++ .../scala/sjsonnet/OptimizerBenchmark.scala | 8 +-- .../main/scala/sjsonnet/ParserBenchmark.scala | 4 +- readme.md | 13 +++++ sjsonnet/src/sjsonnet/Importer.scala | 6 ++- sjsonnet/src/sjsonnet/Interpreter.scala | 10 +++- sjsonnet/src/sjsonnet/Parser.scala | 10 ++-- sjsonnet/src/sjsonnet/StaticOptimizer.scala | 10 +++- sjsonnet/src/sjsonnet/Val.scala | 41 ++++++++++++-- sjsonnet/test/src/sjsonnet/ParserTests.scala | 6 ++- 10 files changed, 140 insertions(+), 22 deletions(-) diff --git a/bench/src/main/scala/sjsonnet/MainBenchmark.scala b/bench/src/main/scala/sjsonnet/MainBenchmark.scala index 7d541de5..8c1c37ec 100644 --- a/bench/src/main/scala/sjsonnet/MainBenchmark.scala +++ b/bench/src/main/scala/sjsonnet/MainBenchmark.scala @@ -66,3 +66,57 @@ class MainBenchmark { )) } } + +// This is a dummy benchmark to see how much memory is used 
by the interpreter. +// You're meant to execute it, and it will generate stats about memory usage before exiting. +// You can optionally pass an argument to instruct it to pause post run - then attach a +// profiler. +object MemoryBenchmark { + + val dummyOut = MainBenchmark.createDummyOut + + val cache = new DefaultParseCache + + def main(args: Array[String]): Unit = { + assert(args.length <= 1, s"Too many arguments: ${args.mkString(",")}") + val pause: Boolean = if (args.length == 1) { + if (args(0) == "--pause") { + println("Run will pause after completion. Attach a profiler then.") + true + } else { + println("Unknown argument: " + args(0)) + System.exit(1) + false + } + } else { + false + } + SjsonnetMain.main0( + MainBenchmark.mainArgs, + cache, + System.in, + dummyOut, + System.err, + os.pwd, + None + ) + println("Pre-GC Stats") + println("============") + println("Total memory: " + Runtime.getRuntime.totalMemory()) + println("Free memory: " + Runtime.getRuntime.freeMemory()) + println("Used memory: " + (Runtime.getRuntime.totalMemory() - Runtime.getRuntime.freeMemory())) + System.gc() + // Wait for GC to finish + Thread.sleep(5000) + println("Post-GC Stats") + println("============") + println("Total memory: " + Runtime.getRuntime.totalMemory()) + println("Free memory: " + Runtime.getRuntime.freeMemory()) + println("Used memory: " + (Runtime.getRuntime.totalMemory() - Runtime.getRuntime.freeMemory())) + + if (pause) { + println("Pausing. 
Attach a profiler") + Thread.sleep(1000000000) + } + } +} diff --git a/bench/src/main/scala/sjsonnet/OptimizerBenchmark.scala b/bench/src/main/scala/sjsonnet/OptimizerBenchmark.scala index ac39dbfa..2e95de54 100644 --- a/bench/src/main/scala/sjsonnet/OptimizerBenchmark.scala +++ b/bench/src/main/scala/sjsonnet/OptimizerBenchmark.scala @@ -3,6 +3,8 @@ package sjsonnet import java.io.StringWriter import java.util.concurrent.TimeUnit +import scala.collection.mutable + import fastparse.Parsed.Success import org.openjdk.jmh.annotations._ import org.openjdk.jmh.infra._ @@ -26,13 +28,13 @@ class OptimizerBenchmark { def setup(): Unit = { val (allFiles, ev) = MainBenchmark.findFiles() this.inputs = allFiles.map { case (p, s) => - fastparse.parse(s, new Parser(p, true).document(_)) match { + fastparse.parse(s, new Parser(p, true, mutable.HashMap.empty, mutable.HashMap.empty).document(_)) match { case Success(v, _) => v } } this.ev = ev val static = inputs.map { - case (expr, fs) => ((new StaticOptimizer(ev, new Std().Std)).optimize(expr), fs) + case (expr, fs) => ((new StaticOptimizer(ev, new Std().Std, mutable.HashMap.empty, mutable.HashMap.empty)).optimize(expr), fs) } val countBefore, countStatic = new Counter inputs.foreach(t => assert(countBefore.transform(t._1) eq t._1)) @@ -45,7 +47,7 @@ class OptimizerBenchmark { @Benchmark def main(bh: Blackhole): Unit = { bh.consume(inputs.foreach { case (expr, fs) => - bh.consume((new StaticOptimizer(ev, new Std().Std)).optimize(expr)) + bh.consume((new StaticOptimizer(ev, new Std().Std, mutable.HashMap.empty, mutable.HashMap.empty)).optimize(expr)) }) } diff --git a/bench/src/main/scala/sjsonnet/ParserBenchmark.scala b/bench/src/main/scala/sjsonnet/ParserBenchmark.scala index 50323319..6100534d 100644 --- a/bench/src/main/scala/sjsonnet/ParserBenchmark.scala +++ b/bench/src/main/scala/sjsonnet/ParserBenchmark.scala @@ -2,6 +2,8 @@ package sjsonnet import java.util.concurrent.TimeUnit +import scala.collection.mutable.HashMap + 
import fastparse.Parsed.Success import org.openjdk.jmh.annotations._ import org.openjdk.jmh.infra._ @@ -25,7 +27,7 @@ class ParserBenchmark { @Benchmark def main(bh: Blackhole): Unit = { bh.consume(allFiles.foreach { case (p, s) => - val res = fastparse.parse(s, new Parser(p, true).document(_)) + val res = fastparse.parse(s, new Parser(p, true, HashMap.empty, HashMap.empty).document(_)) bh.consume(res.asInstanceOf[Success[_]]) }) } diff --git a/readme.md b/readme.md index 44c8b700..37bc82c0 100644 --- a/readme.md +++ b/readme.md @@ -194,6 +194,19 @@ Profiler: sbt bench/run ``` +There's also a benchmark for memory usage: + +Execute and print stats: +``` +sbt 'set fork in run := true' 'set javaOptions in run ++= Seq("-Xmx6G", "-XX:+UseG1GC")' 'bench/runMain sjsonnet.MemoryBenchmark' +``` + +Execute and pause - this is useful if you want to attach a profiler after the run and deep dive into the +object utilization. +``` +sbt 'set fork in run := true' 'set javaOptions in run ++= Seq("-Xmx6G", "-XX:+UseG1GC")' 'bench/runMain sjsonnet.MemoryBenchmark --pause' +``` + ## Laziness The Jsonnet language is _lazy_: expressions don't get evaluated unless diff --git a/sjsonnet/src/sjsonnet/Importer.scala b/sjsonnet/src/sjsonnet/Importer.scala index 3051bf92..15af290c 100644 --- a/sjsonnet/src/sjsonnet/Importer.scala +++ b/sjsonnet/src/sjsonnet/Importer.scala @@ -180,11 +180,13 @@ class CachedImporter(parent: Importer) extends Importer { class CachedResolver( parentImporter: Importer, val parseCache: ParseCache, - strictImportSyntax: Boolean) extends CachedImporter(parentImporter) { + strictImportSyntax: Boolean, + internedStrings: mutable.HashMap[String, String], + internedStaticFieldSets: mutable.HashMap[Val.StaticObjectFieldSet, java.util.LinkedHashMap[String, java.lang.Boolean]]) extends CachedImporter(parentImporter) { def parse(path: Path, content: ResolvedFile)(implicit ev: EvalErrorScope): Either[Error, (Expr, FileScope)] = { parseCache.getOrElseUpdate((path, 
content.contentHash.toString), { - val parsed = fastparse.parse(content.getParserInput(), new Parser(path, strictImportSyntax).document(_)) match { + val parsed = fastparse.parse(content.getParserInput(), new Parser(path, strictImportSyntax, internedStrings, internedStaticFieldSets).document(_)) match { case f @ Parsed.Failure(_, _, _) => val traced = f.trace() val pos = new Position(new FileScope(path), traced.index) diff --git a/sjsonnet/src/sjsonnet/Interpreter.scala b/sjsonnet/src/sjsonnet/Interpreter.scala index c3608f75..3de38fa6 100644 --- a/sjsonnet/src/sjsonnet/Interpreter.scala +++ b/sjsonnet/src/sjsonnet/Interpreter.scala @@ -2,6 +2,8 @@ package sjsonnet import java.io.{PrintWriter, StringWriter} +import scala.collection.mutable + import sjsonnet.Expr.Params import scala.util.control.NonFatal @@ -21,9 +23,13 @@ class Interpreter(extVars: Map[String, String], std: Val.Obj = new Std().Std ) { self => - val resolver = new CachedResolver(importer, parseCache, settings.strictImportSyntax) { + private val internedStrings = new mutable.HashMap[String, String] + + private val internedStaticFieldSets = new mutable.HashMap[Val.StaticObjectFieldSet, java.util.LinkedHashMap[String, java.lang.Boolean]] + + val resolver = new CachedResolver(importer, parseCache, settings.strictImportSyntax, internedStrings, internedStaticFieldSets) { override def process(expr: Expr, fs: FileScope): Either[Error, (Expr, FileScope)] = - handleException(new StaticOptimizer(evaluator, std).optimize(expr), fs) + handleException(new StaticOptimizer(evaluator, std, internedStrings, internedStaticFieldSets).optimize(expr), fs) } private def warn(e: Error): Unit = warnLogger("[warning] " + formatError(e)) diff --git a/sjsonnet/src/sjsonnet/Parser.scala b/sjsonnet/src/sjsonnet/Parser.scala index bec8f8a6..3db4fb85 100644 --- a/sjsonnet/src/sjsonnet/Parser.scala +++ b/sjsonnet/src/sjsonnet/Parser.scala @@ -45,13 +45,13 @@ object Parser { } class Parser(val currentFile: Path, - 
strictImportSyntax: Boolean) { + strictImportSyntax: Boolean, + internedStrings: mutable.HashMap[String, String], + internedStaticFieldSets: mutable.HashMap[Val.StaticObjectFieldSet, java.util.LinkedHashMap[String, java.lang.Boolean]]) { import Parser._ private val fileScope = new FileScope(currentFile) - private val strings = new mutable.HashMap[String, String] - def Pos[_: P]: P[Position] = Index.map(offset => new Position(fileScope, offset)) def id[_: P] = P( @@ -264,7 +264,7 @@ class Parser(val currentFile: Path, def constructString(pos: Position, lines: Seq[String]) = { val s = lines.mkString - val unique = strings.getOrElseUpdate(s, s) + val unique = internedStrings.getOrElseUpdate(s, s) Val.Str(pos, unique) } @@ -335,7 +335,7 @@ class Parser(val currentFile: Path, val a = exprs.iterator.filter(_.isInstanceOf[Expr.Member.AssertStmt]).asInstanceOf[Iterator[Expr.Member.AssertStmt]].toArray if(a.isEmpty) null else a } - if(binds == null && asserts == null && fields.forall(_.isStatic)) Val.staticObject(pos, fields) + if(binds == null && asserts == null && fields.forall(_.isStatic)) Val.staticObject(pos, fields, internedStaticFieldSets, internedStrings) else Expr.ObjBody.MemberList(pos, binds, fields, asserts) case (pos, exprs, Some(comps)) => val preLocals = exprs diff --git a/sjsonnet/src/sjsonnet/StaticOptimizer.scala b/sjsonnet/src/sjsonnet/StaticOptimizer.scala index 4987d048..6e1e2854 100644 --- a/sjsonnet/src/sjsonnet/StaticOptimizer.scala +++ b/sjsonnet/src/sjsonnet/StaticOptimizer.scala @@ -1,11 +1,17 @@ package sjsonnet +import scala.collection.mutable + import Expr._ import ScopedExprTransform._ /** StaticOptimizer performs necessary transformations for the evaluator (assigning ValScope * indices) plus additional optimizations (post-order) and static checking (pre-order). 
*/ -class StaticOptimizer(ev: EvalScope, std: Val.Obj) extends ScopedExprTransform { +class StaticOptimizer( + ev: EvalScope, + std: Val.Obj, + internedStrings: mutable.HashMap[String, String], + internedStaticFieldSets: mutable.HashMap[Val.StaticObjectFieldSet, java.util.LinkedHashMap[String, java.lang.Boolean]]) extends ScopedExprTransform { def optimize(e: Expr): Expr = transform(e) def failOrWarn(msg: String, expr: Expr): Expr = { @@ -61,7 +67,7 @@ class StaticOptimizer(ev: EvalScope, std: Val.Obj) extends ScopedExprTransform { new Val.Arr(a.pos, a.value.map(e => e.asInstanceOf[Val])) case m @ ObjBody.MemberList(pos, binds, fields, asserts) => - if(binds == null && asserts == null && fields.forall(_.isStatic)) Val.staticObject(pos, fields) + if(binds == null && asserts == null && fields.forall(_.isStatic)) Val.staticObject(pos, fields, internedStaticFieldSets, internedStrings) else m case e => e diff --git a/sjsonnet/src/sjsonnet/Val.scala b/sjsonnet/src/sjsonnet/Val.scala index f8e61a1a..db1ac04a 100644 --- a/sjsonnet/src/sjsonnet/Val.scala +++ b/sjsonnet/src/sjsonnet/Val.scala @@ -1,6 +1,7 @@ package sjsonnet import java.util +import java.util.Arrays import sjsonnet.Expr.Member.Visibility import sjsonnet.Expr.Params @@ -297,15 +298,45 @@ object Val{ } } - def staticObject(pos: Position, fields: Array[Expr.Member.Field]): Obj = { + final class StaticObjectFieldSet(protected val keys: Array[String]) { + + override def hashCode(): Int = { + Arrays.hashCode(keys.asInstanceOf[Array[Object]]) + } + + override def equals(obj: scala.Any): Boolean = { + obj match { + case that: StaticObjectFieldSet => + keys.sameElements(that.keys) + case _ => false + } + } + } + + def staticObject( + pos: Position, + fields: Array[Expr.Member.Field], + internedKeyMaps: mutable.HashMap[StaticObjectFieldSet, java.util.LinkedHashMap[String, java.lang.Boolean]], + internedStrings: mutable.HashMap[String, String]): Obj = { + // Set the initial capacity to the number of fields divided by 
the default load factor + 1 - + // this ensures that we can fill up the map to the total number of fields without resizing. + // From JavaDoc - true for both Scala & Java HashMaps + val hashMapDefaultLoadFactor = 0.75f + val capacity = (fields.length / hashMapDefaultLoadFactor).toInt + 1 val cache = mutable.HashMap.empty[Any, Val] - val allKeys = new util.LinkedHashMap[String, java.lang.Boolean] + val allKeys = new util.LinkedHashMap[String, java.lang.Boolean](capacity, hashMapDefaultLoadFactor) + val keys = new Array[String](fields.length) + var idx = 0 fields.foreach { case Expr.Member.Field(_, Expr.FieldName.Fixed(k), _, _, _, rhs: Val.Literal) => - cache.put(k, rhs) - allKeys.put(k, false) + val uniqueKey = internedStrings.getOrElseUpdate(k, k) + cache.put(uniqueKey, rhs) + allKeys.put(uniqueKey, false) + keys(idx) = uniqueKey + idx += 1 } - new Val.Obj(pos, null, true, null, null, cache, allKeys) + val fieldSet = new StaticObjectFieldSet(keys) + new Val.Obj(pos, null, true, null, null, cache, internedKeyMaps.getOrElseUpdate(fieldSet, allKeys)) } abstract class Func(val pos: Position, diff --git a/sjsonnet/test/src/sjsonnet/ParserTests.scala b/sjsonnet/test/src/sjsonnet/ParserTests.scala index 8de38b04..37da8c5b 100644 --- a/sjsonnet/test/src/sjsonnet/ParserTests.scala +++ b/sjsonnet/test/src/sjsonnet/ParserTests.scala @@ -1,11 +1,13 @@ package sjsonnet + +import scala.collection.mutable import utest._ import Expr._ import fastparse.Parsed import Val.{True, Num} object ParserTests extends TestSuite{ - def parse(s: String, strictImportSyntax: Boolean = false) = fastparse.parse(s, new Parser(null, strictImportSyntax).document(_)).get.value._1 - def parseErr(s: String, strictImportSyntax: Boolean = false) = fastparse.parse(s, new Parser(null, strictImportSyntax).document(_), verboseFailures = true).asInstanceOf[Parsed.Failure].msg + def parse(s: String, strictImportSyntax: Boolean = false) = fastparse.parse(s, new Parser(null, strictImportSyntax, 
mutable.HashMap.empty, mutable.HashMap.empty).document(_)).get.value._1 + def parseErr(s: String, strictImportSyntax: Boolean = false) = fastparse.parse(s, new Parser(null, strictImportSyntax, mutable.HashMap.empty, mutable.HashMap.empty).document(_), verboseFailures = true).asInstanceOf[Parsed.Failure].msg val dummyFS = new FileScope(null) def pos(i: Int) = new Position(dummyFS, i) def tests = Tests{