Add early mutators

The mutators that run after dependencies are resolved can be too late
to support build logic that needs to vary the dependencies based on
the mutated axis, for example architecture.  This patch provides an
EarlyMutator interface that can be used to mutate modules before
any dependencies have been resolved.

In order for dependencies to be satisfied in a later pass, every
dependency of a module must have an identical variant, must have a
single variant, or must be inserted using
DynamicDependencyModuleContext.AddVariantDependency.

Change-Id: Ic6ae57e98edfd6c8c09a7788983128d3e4e992f0
diff --git a/context.go b/context.go
index 8f5cedb..2b83adf 100644
--- a/context.go
+++ b/context.go
@@ -62,13 +62,15 @@
 // actions.
 type Context struct {
 	// set at instantiation
-	moduleFactories  map[string]ModuleFactory
-	moduleGroups     map[string]*moduleGroup
-	moduleInfo       map[Module]*moduleInfo
-	modulesSorted    []*moduleInfo
-	singletonInfo    map[string]*singletonInfo
-	mutatorInfo      []*mutatorInfo
-	moduleNinjaNames map[string]*moduleGroup
+	moduleFactories     map[string]ModuleFactory
+	moduleGroups        map[string]*moduleGroup
+	moduleInfo          map[Module]*moduleInfo
+	modulesSorted       []*moduleInfo
+	singletonInfo       map[string]*singletonInfo
+	mutatorInfo         []*mutatorInfo
+	earlyMutatorInfo    []*earlyMutatorInfo
+	variantMutatorNames []string
+	moduleNinjaNames    map[string]*moduleGroup
 
 	dependenciesReady bool // set to true on a successful ResolveDependencies
 	buildActionsReady bool // set to true on a successful PrepareBuildActions
@@ -123,8 +125,9 @@
 		Deps []string
 	}
 
-	variantName string
-	variants    variantMap
+	variantName        string
+	variants           variantMap
+	dependencyVariants variantMap
 
 	logicModule      Module
 	group            *moduleGroup
@@ -147,6 +150,11 @@
 	actionDefs localBuildActions
 }
 
+type Variant struct {
+	Mutator string
+	Variant string
+}
+
 type variantMap map[string]string
 
 func (vm variantMap) clone() variantMap {
@@ -178,6 +186,12 @@
 	name            string
 }
 
+type earlyMutatorInfo struct {
+	// set during RegisterEarlyMutator
+	mutator EarlyMutator
+	name    string
+}
+
 func (e *Error) Error() string {
 
 	return fmt.Sprintf("%s: %s", e.Pos, e.Err)
@@ -311,10 +325,12 @@
 
 // RegisterTopDownMutator registers a mutator that will be invoked to propagate
 // dependency info top-down between Modules.  Each registered mutator
-// is invoked once per Module, and is invoked on a module before being invoked
-// on any of its dependencies
+// is invoked in registration order (mixing TopDownMutators and BottomUpMutators)
+// once per Module, and is invoked on a module before being invoked on any of its
+// dependencies.
 //
-// The mutator type names given here must be unique for the context.
+// The mutator type names given here must be unique to all top down mutators in
+// the Context.
 func (c *Context) RegisterTopDownMutator(name string, mutator TopDownMutator) {
 	for _, m := range c.mutatorInfo {
 		if m.name == name && m.topDownMutator != nil {
@@ -329,13 +345,15 @@
 }
 
 // RegisterBottomUpMutator registers a mutator that will be invoked to split
-// Modules into variants.  Each registered mutator is invoked once per Module,
-// and is invoked on dependencies before being invoked on dependers.
+// Modules into variants.  Each registered mutator is invoked in registration
+// order (mixing TopDownMutators and BottomUpMutators) once per Module, and is
+// invoked on dependencies before being invoked on dependers.
 //
-// The mutator type names given here must be unique for the context.
+// The mutator type names given here must be unique to all bottom up or early
+// mutators in the Context.
 func (c *Context) RegisterBottomUpMutator(name string, mutator BottomUpMutator) {
-	for _, m := range c.mutatorInfo {
-		if m.name == name && m.bottomUpMutator != nil {
+	for _, m := range c.variantMutatorNames {
+		if m == name {
 			panic(fmt.Errorf("mutator name %s is already registered", name))
 		}
 	}
@@ -344,6 +362,35 @@
 		bottomUpMutator: mutator,
 		name:            name,
 	})
+
+	c.variantMutatorNames = append(c.variantMutatorNames, name)
+}
+
+// RegisterEarlyMutator registers a mutator that will be invoked to split
+// Modules into multiple variant Modules before any dependencies have been
+// created.  Each registered mutator is invoked in registration order once
+// per Module (including each variant from previous early mutators).  Module
+// order is unpredictable.
+//
+// In order for dependencies to be satisfied in a later pass, all dependencies
+// of a module must have an identical variant, must have a single variant, or
+// must be added with DynamicDependencyModuleContext.AddVariantDependency.
+//
+// The mutator type names given here must be unique to all bottom up or early
+// mutators in the Context.
+func (c *Context) RegisterEarlyMutator(name string, mutator EarlyMutator) {
+	for _, m := range c.variantMutatorNames {
+		if m == name {
+			panic(fmt.Errorf("mutator name %s is already registered", name))
+		}
+	}
+
+	c.earlyMutatorInfo = append(c.earlyMutatorInfo, &earlyMutatorInfo{
+		mutator: mutator,
+		name:    name,
+	})
+
+	c.variantMutatorNames = append(c.variantMutatorNames, name)
 }
 
 // SetIgnoreUnknownModuleTypes sets the behavior of the context in the case
@@ -711,6 +758,7 @@
 		newModule.directDeps = append([]*moduleInfo(nil), origModule.directDeps...)
 		newModule.logicModule = newLogicModule
 		newModule.variants = newVariants
+		newModule.dependencyVariants = origModule.dependencyVariants.clone()
 		newModule.moduleProperties = newProperties
 
 		if newModule.variantName == "" {
@@ -763,6 +811,17 @@
 	return errs
 }
 
+func (c *Context) prettyPrintVariant(variant variantMap) string {
+	names := make([]string, 0, len(variant))
+	for _, m := range c.variantMutatorNames {
+		if v, ok := variant[m]; ok {
+			names = append(names, m+":"+v)
+		}
+	}
+
+	return strings.Join(names, ", ")
+}
+
 func (c *Context) processModuleDef(moduleDef *parser.Module,
 	relBlueprintsFile string) (*moduleInfo, []error) {
 
@@ -873,13 +932,14 @@
 	return nil
 }
 
-// moduleDepNames returns the sorted list of dependency names for a given
-// module.  If the module implements the DynamicDependerModule interface then
-// this set consists of the union of those module names listed in its "deps"
-// property and those returned by its DynamicDependencies method.  Otherwise it
+// moduleDeps adds dependencies to a module.  If the module implements the
+// DynamicDependerModule interface then this set consists of the union of those
+// module names listed in its "deps" property, those returned by its
+// DynamicDependencies method, and those added by calling AddDependencies or
+// AddVariantDependencies on DynamicDependencyModuleContext.  Otherwise it
 // is simply those names listed in its "deps" property.
-func (c *Context) moduleDepNames(module *moduleInfo,
-	config interface{}) ([]string, []error) {
+func (c *Context) moduleDeps(module *moduleInfo,
+	config interface{}) (errs []error) {
 
 	depNamesSet := make(map[string]bool)
 	depNames := []string{}
@@ -891,19 +951,21 @@
 		}
 	}
 
-	logicModule := module.logicModule
-	dynamicDepender, ok := logicModule.(DynamicDependerModule)
+	dynamicDepender, ok := module.logicModule.(DynamicDependerModule)
 	if ok {
-		ddmctx := &baseModuleContext{
-			context: c,
-			config:  config,
-			module:  module,
+		ddmctx := &dynamicDependerModuleContext{
+			baseModuleContext: baseModuleContext{
+				context: c,
+				config:  config,
+				module:  module,
+			},
+			module: module,
 		}
 
 		dynamicDeps := dynamicDepender.DynamicDependencies(ddmctx)
 
 		if len(ddmctx.errs) > 0 {
-			return nil, ddmctx.errs
+			return ddmctx.errs
 		}
 
 		for _, depName := range dynamicDeps {
@@ -914,29 +976,25 @@
 		}
 	}
 
-	return depNames, nil
+	for _, depName := range depNames {
+		newErrs := c.addDependency(module, depName)
+		if len(newErrs) > 0 {
+			errs = append(errs, newErrs...)
+		}
+	}
+	return errs
 }
 
-// resolveDependencies populates the moduleGroup.modules[0].directDeps list for every
-// module.  In doing so it checks for missing dependencies and self-dependant
-// modules.
+// resolveDependencies populates the directDeps list for every module.  In doing so it checks for
+// missing dependencies and self-dependent modules.
 func (c *Context) resolveDependencies(config interface{}) (errs []error) {
 	for _, group := range c.moduleGroups {
 		for _, module := range group.modules {
-			depNames, newErrs := c.moduleDepNames(module, config)
+			module.directDeps = make([]*moduleInfo, 0, len(module.properties.Deps))
+
+			newErrs := c.moduleDeps(module, config)
 			if len(newErrs) > 0 {
 				errs = append(errs, newErrs...)
-				continue
-			}
-
-			module.directDeps = make([]*moduleInfo, 0, len(depNames))
-
-			for _, depName := range depNames {
-				newErrs := c.addDependency(module, depName)
-				if len(newErrs) > 0 {
-					errs = append(errs, newErrs...)
-					continue
-				}
 			}
 		}
 	}
@@ -963,14 +1021,76 @@
 		}}
 	}
 
-	if len(depInfo.modules) != 1 {
-		panic(fmt.Sprintf("cannot add dependency from %s to %s, it already has multiple variants",
-			module.properties.Name, depInfo.modules[0].properties.Name))
+	for _, m := range module.directDeps {
+		if m.group == depInfo {
+			return nil
+		}
 	}
 
-	module.directDeps = append(module.directDeps, depInfo.modules[0])
+	if len(depInfo.modules) == 1 {
+		module.directDeps = append(module.directDeps, depInfo.modules[0])
+		return nil
+	} else {
+		for _, m := range depInfo.modules {
+			if m.variants.equal(module.dependencyVariants) {
+				module.directDeps = append(module.directDeps, m)
+				return nil
+			}
+		}
+	}
 
-	return nil
+	return []error{&Error{
+		Err: fmt.Errorf("dependency %q of %q missing variant %q",
+			depInfo.modules[0].properties.Name, module.properties.Name,
+			c.prettyPrintVariant(module.dependencyVariants)),
+		Pos: depsPos,
+	}}
+}
+
+func (c *Context) addVariantDependency(module *moduleInfo, variant []Variant,
+	depName string) []error {
+
+	depsPos := module.propertyPos["deps"]
+
+	depInfo, ok := c.moduleGroups[depName]
+	if !ok {
+		return []error{&Error{
+			Err: fmt.Errorf("%q depends on undefined module %q",
+				module.properties.Name, depName),
+			Pos: depsPos,
+		}}
+	}
+
+	// Overlay the requested variants on a copy of dependencyVariants and deep compare
+	// maps; joined variant-name strings would not be in mutator registration order.
+	newVariants := module.dependencyVariants.clone()
+	for _, v := range variant {
+		newVariants[v.Mutator] = v.Variant
+	}
+
+	for _, m := range depInfo.modules {
+		if newVariants.equal(m.variants) {
+			// A module may depend on an earlier variant of itself, but not on itself or a
+			// later variant, since the variants of a module are processed in list order.
+			if m == module {
+				return []error{&Error{Err: fmt.Errorf("%q depends on itself", depName), Pos: depsPos}}
+			} else if depInfo == module.group && beforeInModuleList(module, m, module.group.modules) {
+				return []error{&Error{
+					Err: fmt.Errorf("%q depends on later version of itself", depName),
+					Pos: depsPos,
+				}}
+			}
+			module.directDeps = append(module.directDeps, m)
+			return nil
+		}
+	}
+
+	return []error{&Error{
+		Err: fmt.Errorf("dependency %q of %q missing variant %q",
+			depInfo.modules[0].properties.Name, module.properties.Name,
+			c.prettyPrintVariant(newVariants)),
+		Pos: depsPos,
+	}}
+}
 
 func (c *Context) parallelVisitAllBottomUp(visit func(group *moduleInfo) bool) {
@@ -1148,6 +1268,11 @@
 func (c *Context) PrepareBuildActions(config interface{}) (deps []string, errs []error) {
 	c.buildActionsReady = false
 
+	errs = c.runEarlyMutators(config)
+	if len(errs) > 0 {
+		return nil, errs
+	}
+
 	if !c.dependenciesReady {
 		errs := c.ResolveDependencies(config)
 		if len(errs) > 0 {
@@ -1195,6 +1320,40 @@
 	return deps, nil
 }
 
+func (c *Context) runEarlyMutators(config interface{}) (errs []error) {
+	for _, mutator := range c.earlyMutatorInfo {
+		for _, group := range c.moduleGroups {
+			newModules := make([]*moduleInfo, 0, len(group.modules))
+
+			for _, module := range group.modules {
+				mctx := &mutatorContext{
+					baseModuleContext: baseModuleContext{
+						context: c,
+						config:  config,
+						module:  module,
+					},
+					name: mutator.name,
+				}
+				mutator.mutator(mctx)
+				if len(mctx.errs) > 0 {
+					errs = append(errs, mctx.errs...)
+					return errs
+				}
+
+				if module.splitModules != nil {
+					newModules = append(newModules, module.splitModules...)
+				} else {
+					newModules = append(newModules, module)
+				}
+			}
+
+			group.modules = newModules
+		}
+	}
+
+	return nil
+}
+
 func (c *Context) runMutators(config interface{}) (errs []error) {
 	for _, mutator := range c.mutatorInfo {
 		if mutator.topDownMutator != nil {
@@ -2037,8 +2196,8 @@
 	iName := s[i].properties.Name
 	jName := s[j].properties.Name
 	if iName == jName {
-		iName = s[i].subName()
-		jName = s[j].subName()
+		iName = s[i].variantName
+		jName = s[j].variantName
 	}
 	return iName < jName
 }
@@ -2082,7 +2241,7 @@
 			"typeName":   module.typeName,
 			"goFactory":  factoryName,
 			"pos":        relPos,
-			"variant":    module.subName(),
+			"variant":    module.variantName,
 		}
 		err = headerTemplate.Execute(buf, infoMap)
 		if err != nil {
@@ -2235,6 +2394,23 @@
 	return nil
 }
 
+func beforeInModuleList(a, b *moduleInfo, list []*moduleInfo) bool {
+	found := false
+	for _, l := range list {
+		if l == a {
+			found = true
+		} else if l == b {
+			return found
+		}
+	}
+
+	missing := a
+	if found {
+		missing = b
+	}
+	panic(fmt.Errorf("element %v not found in list %v", missing, list))
+}
+
 var fileHeaderTemplate = `******************************************************************************
 ***            This file is generated and should not be edited             ***
 ******************************************************************************