Skip to content

Commit

Permalink
Make SourceModels constructor responsible for building the Sources
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 651096431
  • Loading branch information
goodfeli authored and Torax team committed Jul 17, 2024
1 parent d9fe95c commit f220226
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 31 deletions.
54 changes: 32 additions & 22 deletions torax/sources/source_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,32 +652,46 @@ class SourceModels:

def __init__(
self,
sources: dict[str, source_lib.Source] | None = None,
source_builders: (
dict[str, source_lib.SourceBuilderProtocol] | None
) = None,
):
"""Constructs a collection of sources.
The constructor should only be called by SourceModelsBuilder.
This class defines which sources are available in a TORAX simulation run.
Users can configure whether each source is actually on and what kind of
profile it produces by changing its runtime configuration (see
runtime_params_lib.py).
Args:
sources: Mapping of source model names to the Source objects. The names
(i.e. the keys of this dictionary) also define the keys in the output
SourceProfiles which are computed from this SourceModels object. NOTE -
Some sources are "special-case": bootstrap current, external current,
and Qei. SourceModels will always instantiate default objects for these
types of sources unless they are provided by this `sources` argument.
Also, their default names are reserved, meaning the input dictionary
`sources` should not have the keys 'j_bootstrap', 'jext', or
'qei_source' unless those sources are one of these "special-case"
sources.
source_builders: Mapping of source model names to builders of the Source
objects. The names (i.e. the keys of this dictionary) also define the
keys in the output SourceProfiles which are computed from this
SourceModels object. NOTE - Some sources are "special-case": bootstrap
current, external current, and Qei. SourceModels will always instantiate
default objects for these types of sources unless they are provided by
this `sources` argument. Also, their default names are reserved, meaning
the input dictionary `sources` should not have the keys 'j_bootstrap',
'jext', or 'qei_source' unless those sources are one of these
"special-case" sources.
Raises:
ValueError if there is a naming collision with the reserved names as
described above.
"""
sources = sources or {}

source_builders = source_builders or {}

# Begin initial construction with sources that don't link back to the
# SourceModels
sources = {
name: builder()
for name, builder in source_builders.items()
if not builder.links_back
}

# Some sources are accessed for specific use cases, so we extract those
# ones and expose them directly.
self._j_bootstrap = None
Expand Down Expand Up @@ -741,6 +755,11 @@ def __init__(
else:
self.add_source(source_name, source)

# Now add the sources that link back
for name, builder in source_builders.items():
if builder.links_back:
self.add_source(name, builder(self))

def add_source(
self,
source_name: str,
Expand Down Expand Up @@ -931,16 +950,7 @@ def __init__(

def __call__(self) -> SourceModels:

unlinked_sources = {
name: builder()
for name, builder in self.source_builders.items()
if not builder.links_back
}
initial_model = SourceModels(unlinked_sources)
for name, builder in self.source_builders.items():
if builder.links_back:
initial_model.add_source(name, builder(initial_model))
return initial_model
return SourceModels(self.source_builders)

@property
def runtime_params(self) -> dict[str, runtime_params_lib.RuntimeParams]:
Expand Down
20 changes: 11 additions & 9 deletions torax/sources/tests/source_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,22 +215,22 @@ def test_cannot_add_multiple_special_case_sources(self):
with self.assertRaises(ValueError):
source_models_lib.SourceModels(
dict(
j_bootstrap=bootstrap_current_source.BootstrapCurrentSource(),
j_bootstrap2=bootstrap_current_source.BootstrapCurrentSource(),
j_bootstrap=bootstrap_current_source.BootstrapCurrentSourceBuilder(),
j_bootstrap2=bootstrap_current_source.BootstrapCurrentSourceBuilder(),
)
)
with self.assertRaises(ValueError):
source_models_lib.SourceModels(
dict(
qei=qei_source.QeiSource(),
qei2=qei_source.QeiSource(),
qei=qei_source.QeiSourceBuilder(),
qei2=qei_source.QeiSourceBuilder(),
)
)
with self.assertRaises(ValueError):
source_models_lib.SourceModels(
dict(
external_current=external_current_source.ExternalCurrentSource(),
external_current2=external_current_source.ExternalCurrentSource(),
external_current=external_current_source.ExternalCurrentSourceBuilder(),
external_current2=external_current_source.ExternalCurrentSourceBuilder(),
)
)
source_models = source_models_lib.SourceModels()
Expand All @@ -248,16 +248,18 @@ def test_cannot_add_multiple_special_case_sources(self):
def test_cannot_add_multiple_sources_with_same_name(self):
"""Tests that SourceModels cannot add multiple sources with same name."""
source_name = 'foo'
foo_source = source_lib.Source(
foo_source_builder = source_lib.SourceBuilder(
affected_core_profiles=(source_lib.AffectedCoreProfile.TEMP_EL,),
supported_modes=(runtime_params_lib.Mode.ZERO,),
)
source_models = source_models_lib.SourceModels(
sources={source_name: foo_source},
source_builders={source_name: foo_source_builder},
)
# It's built once by the SourceModels constructor
rebuilt = foo_source_builder()
# Cannot add another source with that name again.
with self.assertRaises(ValueError):
source_models.add_source(source_name, foo_source)
source_models.add_source(source_name, rebuilt)


if __name__ == '__main__':
Expand Down

0 comments on commit f220226

Please sign in to comment.