blob: 645f0c1a5ff2c11ff4bffb23656706c4c0ad81a2 [file] [log] [blame]
Damien Neild025c952020-02-02 00:53:34 -08001// Copyright 2020 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package protobuild constructs messages.
6//
7// This package is used to construct multiple types of message with a similar shape
8// from a common template.
9package protobuild
10
11import (
12 "fmt"
13 "math"
14 "reflect"
15
16 pref "google.golang.org/protobuf/reflect/protoreflect"
17 "google.golang.org/protobuf/reflect/protoregistry"
18)
19
// A Value is a value assignable to a field.
// A Value may be any value accepted by protoreflect.ValueOf. In addition:
//
// • An int may be assigned to any numeric or enum field. An int assigned to
// a 32-bit or unsigned field must fit in that field's range.
//
// • A float64 may be assigned to a float or double field.
//
// • Either a string or []byte may be assigned to a string or bytes field.
//
// • A string containing the value name may be assigned to an enum field.
//
// • A Message may be assigned to a message field.
//
// • A slice may be assigned to a list, and a map may be assigned to a map.
// Elements of a message-typed list and values of a message-typed map are
// given as Message templates.
type Value interface{}
33
34// A Message is a template to apply to a message. Keys are field names, including
35// extension names.
36type Message map[pref.Name]Value
37
const (
	// Unknown is the reserved Message key for a message's unknown fields,
	// rather than for a named field. The value should be a []byte.
	Unknown = "@unknown"
)
41
42// Build applies the template to a message.
43func (template Message) Build(m pref.Message) {
44 md := m.Descriptor()
45 fields := md.Fields()
46 exts := make(map[pref.Name]pref.FieldDescriptor)
47 protoregistry.GlobalTypes.RangeExtensionsByMessage(md.FullName(), func(xt pref.ExtensionType) bool {
48 xd := xt.TypeDescriptor()
49 exts[xd.Name()] = xd
50 return true
51 })
52 for k, v := range template {
53 if k == Unknown {
54 m.SetUnknown(pref.RawFields(v.([]byte)))
55 continue
56 }
57 fd := fields.ByName(k)
58 if fd == nil {
59 fd = exts[k]
60 }
61 if fd == nil {
62 panic(fmt.Sprintf("%v.%v: not found", md.FullName(), k))
63 }
64 switch {
65 case fd.IsList():
66 list := m.Mutable(fd).List()
67 s := reflect.ValueOf(v)
68 for i := 0; i < s.Len(); i++ {
69 if fd.Message() == nil {
70 list.Append(fieldValue(fd, s.Index(i).Interface()))
71 } else {
72 e := list.NewElement()
73 s.Index(i).Interface().(Message).Build(e.Message())
74 list.Append(e)
75 }
76 }
77 case fd.IsMap():
78 mapv := m.Mutable(fd).Map()
79 rm := reflect.ValueOf(v)
80 for _, k := range rm.MapKeys() {
81 mk := fieldValue(fd.MapKey(), k.Interface()).MapKey()
82 if fd.MapValue().Message() == nil {
83 mv := fieldValue(fd.MapValue(), rm.MapIndex(k).Interface())
84 mapv.Set(mk, mv)
85 } else if mapv.Has(mk) {
86 mv := mapv.Get(mk).Message()
87 rm.MapIndex(k).Interface().(Message).Build(mv)
88 } else {
89 mv := mapv.NewValue()
90 rm.MapIndex(k).Interface().(Message).Build(mv.Message())
91 mapv.Set(mk, mv)
92 }
93 }
94 default:
95 if fd.Message() == nil {
96 m.Set(fd, fieldValue(fd, v))
97 } else {
98 v.(Message).Build(m.Mutable(fd).Message())
99 }
100 }
101
102 }
103}
104
105func fieldValue(fd pref.FieldDescriptor, v interface{}) pref.Value {
106 switch o := v.(type) {
107 case int:
108 switch fd.Kind() {
109 case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind:
110 if min, max := math.MinInt32, math.MaxInt32; o < min || o > max {
111 panic(fmt.Sprintf("%v: value %v out of range [%v, %v]", fd.FullName(), o, min, max))
112 }
113 v = int32(o)
114 case pref.Uint32Kind, pref.Fixed32Kind:
115 if min, max := 0, math.MaxUint32; o < min || o > max {
116 panic(fmt.Sprintf("%v: value %v out of range [%v, %v]", fd.FullName(), o, min, max))
117 }
118 v = uint32(o)
119 case pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
120 v = int64(o)
121 case pref.Uint64Kind, pref.Fixed64Kind:
122 if o < 0 {
123 panic(fmt.Sprintf("%v: value %v out of range [%v, %v]", fd.FullName(), o, 0, uint64(math.MaxUint64)))
124 }
125 v = uint64(o)
126 case pref.FloatKind:
127 v = float32(o)
128 case pref.DoubleKind:
129 v = float64(o)
130 case pref.EnumKind:
131 v = pref.EnumNumber(o)
132 default:
133 panic(fmt.Sprintf("%v: invalid value type int", fd.FullName()))
134 }
135 case float64:
136 switch fd.Kind() {
137 case pref.FloatKind:
138 v = float32(o)
139 }
140 case string:
141 switch fd.Kind() {
142 case pref.BytesKind:
143 v = []byte(o)
144 case pref.EnumKind:
145 v = fd.Enum().Values().ByName(pref.Name(o)).Number()
146 }
147 }
148 return pref.ValueOf(v)
149}