reflect/protoreflect: add alternative message reflection API

Added API:
	Message.Len
	Message.Range
	Message.Has
	Message.Clear
	Message.Get
	Message.Set
	Message.Mutable
	Message.NewMessage
	Message.WhichOneof
	Message.GetUnknown
	Message.SetUnknown

Deprecated API (to be removed in subsequent CL):
	Message.KnownFields
	Message.UnknownFields

The primary difference with the new API is that the top-level
Message methods are keyed by FieldDescriptor rather than FieldNumber
with the following semantics:
* For known fields, the FieldDescriptor must exactly match the
field descriptor known by the message.
* For extension fields, the FieldDescriptor must implement ExtensionType,
where ContainingMessage.FullName matches the message name, and
the field number is within the message's extension range.
When setting an extension field, it automatically stores
the extension type information.
* Extension fields are always considered nullable,
implying that repeated extension fields are nullable.
That is, you can distinguish between an unpopulated list and an empty list.
* Message.Get always returns a valid Value even if unpopulated.
The behavior is already well-defined for scalars, but for unpopulated
composite types, it now returns an empty, read-only value of that type.

Change-Id: Ia120630b4db221aeaaf743d0f64160e1a61a0f61
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/175458
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/encoding/bench_test.go b/encoding/bench_test.go
index ab0fbbd..3b64aab 100644
--- a/encoding/bench_test.go
+++ b/encoding/bench_test.go
@@ -44,31 +44,28 @@
 		return
 	}
 
-	knownFields := m.KnownFields()
 	fieldDescs := m.Descriptor().Fields()
 	for i := 0; i < fieldDescs.Len(); i++ {
 		fd := fieldDescs.Get(i)
-		num := fd.Number()
 		switch {
 		case fd.IsList():
-			setList(knownFields.Get(num).List(), fd, level)
+			setList(m.Mutable(fd).List(), fd, level)
 		case fd.IsMap():
-			setMap(knownFields.Get(num).Map(), fd, level)
+			setMap(m.Mutable(fd).Map(), fd, level)
 		default:
-			setScalarField(knownFields, fd, level)
+			setScalarField(m, fd, level)
 		}
 	}
 }
 
-func setScalarField(knownFields pref.KnownFields, fd pref.FieldDescriptor, level int) {
-	num := fd.Number()
+func setScalarField(m pref.Message, fd pref.FieldDescriptor, level int) {
 	switch fd.Kind() {
 	case pref.MessageKind, pref.GroupKind:
-		m := knownFields.NewMessage(num)
-		fillMessage(m, level+1)
-		knownFields.Set(num, pref.ValueOf(m))
+		m2 := m.NewMessage(fd)
+		fillMessage(m2, level+1)
+		m.Set(fd, pref.ValueOf(m2))
 	default:
-		knownFields.Set(num, scalarField(fd.Kind()))
+		m.Set(fd, scalarField(fd.Kind()))
 	}
 }
 
diff --git a/encoding/protojson/decode.go b/encoding/protojson/decode.go
index a40d1e2..29bf114 100644
--- a/encoding/protojson/decode.go
+++ b/encoding/protojson/decode.go
@@ -52,11 +52,10 @@
 // setting the fields. If it returns an error, the given message may be
 // partially set.
 func (o UnmarshalOptions) Unmarshal(b []byte, m proto.Message) error {
-	mr := m.ProtoReflect()
 	// TODO: Determine if we would like to have an option for merging or only
-	// have merging behavior.  We should at least be consistent with textproto
+	// have merging behavior. We should at least be consistent with textproto
 	// marshaling.
-	resetMessage(mr)
+	proto.Reset(m)
 
 	if o.Resolver == nil {
 		o.Resolver = protoregistry.GlobalTypes
@@ -64,7 +63,7 @@
 	o.decoder = json.NewDecoder(b)
 
 	var nerr errors.NonFatal
-	if err := o.unmarshalMessage(mr, false); !nerr.Merge(err) {
+	if err := o.unmarshalMessage(m.ProtoReflect(), false); !nerr.Merge(err) {
 		return err
 	}
 
@@ -83,25 +82,6 @@
 	return nerr.E
 }
 
-// resetMessage clears all fields of given protoreflect.Message.
-func resetMessage(m pref.Message) {
-	knownFields := m.KnownFields()
-	knownFields.Range(func(num pref.FieldNumber, _ pref.Value) bool {
-		knownFields.Clear(num)
-		return true
-	})
-	unknownFields := m.UnknownFields()
-	unknownFields.Range(func(num pref.FieldNumber, _ pref.RawFields) bool {
-		unknownFields.Set(num, nil)
-		return true
-	})
-	extTypes := knownFields.ExtensionTypes()
-	extTypes.Range(func(xt pref.ExtensionType) bool {
-		extTypes.Remove(xt)
-		return true
-	})
-}
-
 // unexpectedJSONError is an error that contains the unexpected json.Value. This
 // is returned by methods to provide callers the read json.Value that it did not
 // expect.
@@ -164,9 +144,7 @@
 	var seenOneofs set.Ints
 
 	messageDesc := m.Descriptor()
-	knownFields := m.KnownFields()
 	fieldDescs := messageDesc.Fields()
-	xtTypes := knownFields.ExtensionTypes()
 
 Loop:
 	for {
@@ -200,20 +178,12 @@
 		var fd pref.FieldDescriptor
 		if strings.HasPrefix(name, "[") && strings.HasSuffix(name, "]") {
 			// Only extension names are in [name] format.
-			xtName := pref.FullName(name[1 : len(name)-1])
-			xt := xtTypes.ByName(xtName)
-			if xt == nil {
-				xt, err = o.findExtension(xtName)
-				if err != nil && err != protoregistry.NotFound {
-					return errors.New("unable to resolve [%v]: %v", xtName, err)
-				}
-				if xt != nil {
-					xtTypes.Register(xt)
-				}
+			extName := pref.FullName(name[1 : len(name)-1])
+			extType, err := o.findExtension(extName)
+			if err != nil && err != protoregistry.NotFound {
+				return errors.New("unable to resolve [%v]: %v", extName, err)
 			}
-			if xt != nil {
-				fd = xt.Descriptor()
-			}
+			fd = extType
 		} else {
 			// The name can either be the JSON name or the proto field name.
 			fd = fieldDescs.ByJSONName(name)
@@ -249,12 +219,12 @@
 
 		switch {
 		case fd.IsList():
-			list := knownFields.Get(fd.Number()).List()
+			list := m.Mutable(fd).List()
 			if err := o.unmarshalList(list, fd); !nerr.Merge(err) {
 				return errors.New("%v|%q: %v", fd.FullName(), name, err)
 			}
 		case fd.IsMap():
-			mmap := knownFields.Get(fd.Number()).Map()
+			mmap := m.Mutable(fd).Map()
 			if err := o.unmarshalMap(mmap, fd); !nerr.Merge(err) {
 				return errors.New("%v|%q: %v", fd.FullName(), name, err)
 			}
@@ -269,7 +239,7 @@
 			}
 
 			// Required or optional fields.
-			if err := o.unmarshalSingular(knownFields, fd); !nerr.Merge(err) {
+			if err := o.unmarshalSingular(m, fd); !nerr.Merge(err) {
 				return errors.New("%v|%q: %v", fd.FullName(), name, err)
 			}
 		}
@@ -305,16 +275,14 @@
 
 // unmarshalSingular unmarshals to the non-repeated field specified by the given
 // FieldDescriptor.
-func (o UnmarshalOptions) unmarshalSingular(knownFields pref.KnownFields, fd pref.FieldDescriptor) error {
+func (o UnmarshalOptions) unmarshalSingular(m pref.Message, fd pref.FieldDescriptor) error {
 	var val pref.Value
 	var err error
-	num := fd.Number()
-
 	switch fd.Kind() {
 	case pref.MessageKind, pref.GroupKind:
-		m := knownFields.NewMessage(num)
-		err = o.unmarshalMessage(m, false)
-		val = pref.ValueOf(m)
+		m2 := m.NewMessage(fd)
+		err = o.unmarshalMessage(m2, false)
+		val = pref.ValueOf(m2)
 	default:
 		val, err = o.unmarshalScalar(fd)
 	}
@@ -323,7 +291,7 @@
 	if !nerr.Merge(err) {
 		return err
 	}
-	knownFields.Set(num, val)
+	m.Set(fd, val)
 	return nerr.E
 }
 
diff --git a/encoding/protojson/decode_test.go b/encoding/protojson/decode_test.go
index b713a38..e3427b8 100644
--- a/encoding/protojson/decode_test.go
+++ b/encoding/protojson/decode_test.go
@@ -1371,19 +1371,6 @@
 			return m
 		}(),
 	}, {
-		desc:         "extension field set to null",
-		inputMessage: &pb2.Extensions{},
-		inputText: `{
-  "[pb2.ExtensionsContainer.opt_ext_bool]": null,
-  "[pb2.ExtensionsContainer.opt_ext_nested]": null
-}`,
-		wantMessage: func() proto.Message {
-			m := &pb2.Extensions{}
-			setExtension(m, pb2.E_ExtensionsContainer_OptExtBool, nil)
-			setExtension(m, pb2.E_ExtensionsContainer_OptExtNested, nil)
-			return m
-		}(),
-	}, {
 		desc:         "extensions of repeated field contains null",
 		inputMessage: &pb2.Extensions{},
 		inputText: `{
diff --git a/encoding/protojson/encode.go b/encoding/protojson/encode.go
index d789986..45d052c 100644
--- a/encoding/protojson/encode.go
+++ b/encoding/protojson/encode.go
@@ -88,20 +88,17 @@
 // marshalFields marshals the fields in the given protoreflect.Message.
 func (o MarshalOptions) marshalFields(m pref.Message) error {
 	var nerr errors.NonFatal
-	fieldDescs := m.Descriptor().Fields()
-	knownFields := m.KnownFields()
 
 	// Marshal out known fields.
+	fieldDescs := m.Descriptor().Fields()
 	for i := 0; i < fieldDescs.Len(); i++ {
 		fd := fieldDescs.Get(i)
-		num := fd.Number()
-
-		if !knownFields.Has(num) {
+		if !m.Has(fd) {
 			continue
 		}
 
 		name := fd.JSONName()
-		val := knownFields.Get(num)
+		val := m.Get(fd)
 		if err := o.encoder.WriteName(name); !nerr.Merge(err) {
 			return err
 		}
@@ -111,7 +108,7 @@
 	}
 
 	// Marshal out extensions.
-	if err := o.marshalExtensions(knownFields); !nerr.Merge(err) {
+	if err := o.marshalExtensions(m); !nerr.Merge(err) {
 		return err
 	}
 	return nerr.E
@@ -254,34 +251,33 @@
 }
 
 // marshalExtensions marshals extension fields.
-func (o MarshalOptions) marshalExtensions(knownFields pref.KnownFields) error {
-	type xtEntry struct {
-		key    string
-		value  pref.Value
-		xtType pref.ExtensionType
+func (o MarshalOptions) marshalExtensions(m pref.Message) error {
+	type entry struct {
+		key   string
+		value pref.Value
+		desc  pref.FieldDescriptor
 	}
 
-	xtTypes := knownFields.ExtensionTypes()
-
 	// Get a sorted list based on field key first.
-	entries := make([]xtEntry, 0, xtTypes.Len())
-	xtTypes.Range(func(xt pref.ExtensionType) bool {
-		name := xt.Descriptor().FullName()
+	var entries []entry
+	m.Range(func(fd pref.FieldDescriptor, v pref.Value) bool {
+		if !fd.IsExtension() {
+			return true
+		}
+		xt := fd.(pref.ExtensionType)
+
 		// If extended type is a MessageSet, set field name to be the message type name.
+		name := xt.Descriptor().FullName()
 		if isMessageSetExtension(xt) {
 			name = xt.Descriptor().Message().FullName()
 		}
 
-		num := xt.Descriptor().Number()
-		if knownFields.Has(num) {
-			// Use [name] format for JSON field name.
-			pval := knownFields.Get(num)
-			entries = append(entries, xtEntry{
-				key:    string(name),
-				value:  pval,
-				xtType: xt,
-			})
-		}
+		// Use [name] format for JSON field name.
+		entries = append(entries, entry{
+			key:   string(name),
+			value: v,
+			desc:  fd,
+		})
 		return true
 	})
 
@@ -299,7 +295,7 @@
 		if err := o.encoder.WriteName("[" + entry.key + "]"); !nerr.Merge(err) {
 			return err
 		}
-		if err := o.marshalValue(entry.value, entry.xtType.Descriptor()); !nerr.Merge(err) {
+		if err := o.marshalValue(entry.value, entry.desc); !nerr.Merge(err) {
 			return err
 		}
 	}
diff --git a/encoding/protojson/encode_test.go b/encoding/protojson/encode_test.go
index 3483aed..691fcf7 100644
--- a/encoding/protojson/encode_test.go
+++ b/encoding/protojson/encode_test.go
@@ -15,7 +15,6 @@
 	"github.com/google/go-cmp/cmp/cmpopts"
 	"google.golang.org/protobuf/encoding/protojson"
 	"google.golang.org/protobuf/internal/encoding/pack"
-	"google.golang.org/protobuf/internal/encoding/wire"
 	pimpl "google.golang.org/protobuf/internal/impl"
 	"google.golang.org/protobuf/internal/scalar"
 	"google.golang.org/protobuf/proto"
@@ -50,15 +49,9 @@
 	return p
 }
 
+// TODO: Replace this with proto.SetExtension.
 func setExtension(m proto.Message, xd *protoiface.ExtensionDescV1, val interface{}) {
-	knownFields := m.ProtoReflect().KnownFields()
-	extTypes := knownFields.ExtensionTypes()
-	extTypes.Register(xd.Type)
-	if val == nil {
-		return
-	}
-	pval := xd.Type.ValueOf(val)
-	knownFields.Set(wire.Number(xd.Field), pval)
+	m.ProtoReflect().Set(xd.Type, xd.Type.ValueOf(val))
 }
 
 // dhex decodes a hex-string and returns the bytes and panics if s is invalid.
@@ -945,14 +938,6 @@
   "[pb2.opt_ext_string]": "extension field"
 }`,
 	}, {
-		desc: "extension message field set to nil",
-		input: func() proto.Message {
-			m := &pb2.Extensions{}
-			setExtension(m, pb2.E_OptExtNested, nil)
-			return m
-		}(),
-		want: "{}",
-	}, {
 		desc: "extensions of repeated fields",
 		input: func() proto.Message {
 			m := &pb2.Extensions{}
diff --git a/encoding/protojson/well_known_types.go b/encoding/protojson/well_known_types.go
index 6aedebd..8f1205c 100644
--- a/encoding/protojson/well_known_types.go
+++ b/encoding/protojson/well_known_types.go
@@ -142,25 +142,26 @@
 // field `value` which holds the custom JSON in addition to the `@type` field.
 
 func (o MarshalOptions) marshalAny(m pref.Message) error {
-	messageDesc := m.Descriptor()
-	knownFields := m.KnownFields()
+	fds := m.Descriptor().Fields()
+	fdType := fds.ByNumber(fieldnum.Any_TypeUrl)
+	fdValue := fds.ByNumber(fieldnum.Any_Value)
 
 	// Start writing the JSON object.
 	o.encoder.StartObject()
 	defer o.encoder.EndObject()
 
-	if !knownFields.Has(fieldnum.Any_TypeUrl) {
-		if !knownFields.Has(fieldnum.Any_Value) {
+	if !m.Has(fdType) {
+		if !m.Has(fdValue) {
 			// If message is empty, marshal out empty JSON object.
 			return nil
 		} else {
 			// Return error if type_url field is not set, but value is set.
-			return errors.New("%s: type_url is not set", messageDesc.FullName())
+			return errors.New("%s: type_url is not set", m.Descriptor().FullName())
 		}
 	}
 
-	typeVal := knownFields.Get(fieldnum.Any_TypeUrl)
-	valueVal := knownFields.Get(fieldnum.Any_Value)
+	typeVal := m.Get(fdType)
+	valueVal := m.Get(fdValue)
 
 	// Marshal out @type field.
 	typeURL := typeVal.String()
@@ -173,7 +174,7 @@
 	// Resolve the type in order to unmarshal value field.
 	emt, err := o.Resolver.FindMessageByURL(typeURL)
 	if !nerr.Merge(err) {
-		return errors.New("%s: unable to resolve %q: %v", messageDesc.FullName(), typeURL, err)
+		return errors.New("%s: unable to resolve %q: %v", m.Descriptor().FullName(), typeURL, err)
 	}
 
 	em := emt.New()
@@ -185,7 +186,7 @@
 		AllowPartial: o.AllowPartial,
 	}.Unmarshal(valueVal.Bytes(), em.Interface())
 	if !nerr.Merge(err) {
-		return errors.New("%s: unable to unmarshal %q: %v", messageDesc.FullName(), typeURL, err)
+		return errors.New("%s: unable to unmarshal %q: %v", m.Descriptor().FullName(), typeURL, err)
 	}
 
 	// If type of value has custom JSON encoding, marshal out a field "value"
@@ -263,9 +264,12 @@
 		return errors.New("google.protobuf.Any: %v", err)
 	}
 
-	knownFields := m.KnownFields()
-	knownFields.Set(fieldnum.Any_TypeUrl, pref.ValueOf(typeURL))
-	knownFields.Set(fieldnum.Any_Value, pref.ValueOf(b))
+	fds := m.Descriptor().Fields()
+	fdType := fds.ByNumber(fieldnum.Any_TypeUrl)
+	fdValue := fds.ByNumber(fieldnum.Any_Value)
+
+	m.Set(fdType, pref.ValueOf(typeURL))
+	m.Set(fdValue, pref.ValueOf(b))
 	return nerr.E
 }
 
@@ -446,7 +450,7 @@
 
 func (o MarshalOptions) marshalWrapperType(m pref.Message) error {
 	fd := m.Descriptor().Fields().ByNumber(wrapperFieldNumber)
-	val := m.KnownFields().Get(wrapperFieldNumber)
+	val := m.Get(fd)
 	return o.marshalSingular(val, fd)
 }
 
@@ -457,7 +461,7 @@
 	if !nerr.Merge(err) {
 		return err
 	}
-	m.KnownFields().Set(wrapperFieldNumber, val)
+	m.Set(fd, val)
 	return nerr.E
 }
 
@@ -509,14 +513,12 @@
 
 func (o MarshalOptions) marshalStruct(m pref.Message) error {
 	fd := m.Descriptor().Fields().ByNumber(fieldnum.Struct_Fields)
-	val := m.KnownFields().Get(fieldnum.Struct_Fields)
-	return o.marshalMap(val.Map(), fd)
+	return o.marshalMap(m.Get(fd).Map(), fd)
 }
 
 func (o UnmarshalOptions) unmarshalStruct(m pref.Message) error {
 	fd := m.Descriptor().Fields().ByNumber(fieldnum.Struct_Fields)
-	val := m.KnownFields().Get(fieldnum.Struct_Fields)
-	return o.unmarshalMap(val.Map(), fd)
+	return o.unmarshalMap(m.Mutable(fd).Map(), fd)
 }
 
 // The JSON representation for ListValue is JSON array that contains the encoded
@@ -525,14 +527,12 @@
 
 func (o MarshalOptions) marshalListValue(m pref.Message) error {
 	fd := m.Descriptor().Fields().ByNumber(fieldnum.ListValue_Values)
-	val := m.KnownFields().Get(fieldnum.ListValue_Values)
-	return o.marshalList(val.List(), fd)
+	return o.marshalList(m.Get(fd).List(), fd)
 }
 
 func (o UnmarshalOptions) unmarshalListValue(m pref.Message) error {
 	fd := m.Descriptor().Fields().ByNumber(fieldnum.ListValue_Values)
-	val := m.KnownFields().Get(fieldnum.ListValue_Values)
-	return o.unmarshalList(val.List(), fd)
+	return o.unmarshalList(m.Mutable(fd).List(), fd)
 }
 
 // The JSON representation for a Value is dependent on the oneof field that is
@@ -540,27 +540,21 @@
 // Value message needs to be a oneof field set, else it is an error.
 
 func (o MarshalOptions) marshalKnownValue(m pref.Message) error {
-	messageDesc := m.Descriptor()
-	knownFields := m.KnownFields()
-	num := knownFields.WhichOneof("kind")
-	if num == 0 {
-		// Return error if none of the fields is set.
-		return errors.New("%s: none of the oneof fields is set", messageDesc.FullName())
+	od := m.Descriptor().Oneofs().ByName("kind")
+	fd := m.WhichOneof(od)
+	if fd == nil {
+		return errors.New("%s: none of the oneof fields is set", m.Descriptor().FullName())
 	}
-
-	fd := messageDesc.Fields().ByNumber(num)
-	val := knownFields.Get(num)
-	return o.marshalSingular(val, fd)
+	return o.marshalSingular(m.Get(fd), fd)
 }
 
 func (o UnmarshalOptions) unmarshalKnownValue(m pref.Message) error {
 	var nerr errors.NonFatal
-	knownFields := m.KnownFields()
-
 	switch o.decoder.Peek() {
 	case json.Null:
 		o.decoder.Read()
-		knownFields.Set(fieldnum.Value_NullValue, pref.ValueOf(pref.EnumNumber(0)))
+		fd := m.Descriptor().Fields().ByNumber(fieldnum.Value_NullValue)
+		m.Set(fd, pref.ValueOf(pref.EnumNumber(0)))
 
 	case json.Bool:
 		jval, err := o.decoder.Read()
@@ -571,7 +565,8 @@
 		if err != nil {
 			return err
 		}
-		knownFields.Set(fieldnum.Value_BoolValue, val)
+		fd := m.Descriptor().Fields().ByNumber(fieldnum.Value_BoolValue)
+		m.Set(fd, val)
 
 	case json.Number:
 		jval, err := o.decoder.Read()
@@ -582,7 +577,8 @@
 		if err != nil {
 			return err
 		}
-		knownFields.Set(fieldnum.Value_NumberValue, val)
+		fd := m.Descriptor().Fields().ByNumber(fieldnum.Value_NumberValue)
+		m.Set(fd, val)
 
 	case json.String:
 		// A JSON string may have been encoded from the number_value field,
@@ -599,21 +595,24 @@
 		if !nerr.Merge(err) {
 			return err
 		}
-		knownFields.Set(fieldnum.Value_StringValue, val)
+		fd := m.Descriptor().Fields().ByNumber(fieldnum.Value_StringValue)
+		m.Set(fd, val)
 
 	case json.StartObject:
-		m := knownFields.NewMessage(fieldnum.Value_StructValue)
-		if err := o.unmarshalStruct(m); !nerr.Merge(err) {
+		fd := m.Descriptor().Fields().ByNumber(fieldnum.Value_StructValue)
+		m2 := m.NewMessage(fd)
+		if err := o.unmarshalStruct(m2); !nerr.Merge(err) {
 			return err
 		}
-		knownFields.Set(fieldnum.Value_StructValue, pref.ValueOf(m))
+		m.Set(fd, pref.ValueOf(m2))
 
 	case json.StartArray:
-		m := knownFields.NewMessage(fieldnum.Value_ListValue)
-		if err := o.unmarshalListValue(m); !nerr.Merge(err) {
+		fd := m.Descriptor().Fields().ByNumber(fieldnum.Value_ListValue)
+		m2 := m.NewMessage(fd)
+		if err := o.unmarshalListValue(m2); !nerr.Merge(err) {
 			return err
 		}
-		knownFields.Set(fieldnum.Value_ListValue, pref.ValueOf(m))
+		m.Set(fd, pref.ValueOf(m2))
 
 	default:
 		jval, err := o.decoder.Read()
@@ -622,7 +621,6 @@
 		}
 		return unexpectedJSONError{jval}
 	}
-
 	return nerr.E
 }
 
@@ -644,21 +642,22 @@
 )
 
 func (o MarshalOptions) marshalDuration(m pref.Message) error {
-	messageDesc := m.Descriptor()
-	knownFields := m.KnownFields()
+	fds := m.Descriptor().Fields()
+	fdSeconds := fds.ByNumber(fieldnum.Duration_Seconds)
+	fdNanos := fds.ByNumber(fieldnum.Duration_Nanos)
 
-	secsVal := knownFields.Get(fieldnum.Duration_Seconds)
-	nanosVal := knownFields.Get(fieldnum.Duration_Nanos)
+	secsVal := m.Get(fdSeconds)
+	nanosVal := m.Get(fdNanos)
 	secs := secsVal.Int()
 	nanos := nanosVal.Int()
 	if secs < -maxSecondsInDuration || secs > maxSecondsInDuration {
-		return errors.New("%s: seconds out of range %v", messageDesc.FullName(), secs)
+		return errors.New("%s: seconds out of range %v", m.Descriptor().FullName(), secs)
 	}
 	if nanos < -secondsInNanos || nanos > secondsInNanos {
-		return errors.New("%s: nanos out of range %v", messageDesc.FullName(), nanos)
+		return errors.New("%s: nanos out of range %v", m.Descriptor().FullName(), nanos)
 	}
 	if (secs > 0 && nanos < 0) || (secs < 0 && nanos > 0) {
-		return errors.New("%s: signs of seconds and nanos do not match", messageDesc.FullName())
+		return errors.New("%s: signs of seconds and nanos do not match", m.Descriptor().FullName())
 	}
 	// Generated output always contains 0, 3, 6, or 9 fractional digits,
 	// depending on required precision, followed by the suffix "s".
@@ -687,21 +686,23 @@
 		return unexpectedJSONError{jval}
 	}
 
-	messageDesc := m.Descriptor()
 	input := jval.String()
 	secs, nanos, ok := parseDuration(input)
 	if !ok {
-		return errors.New("%s: invalid duration value %q", messageDesc.FullName(), input)
+		return errors.New("%s: invalid duration value %q", m.Descriptor().FullName(), input)
 	}
 	// Validate seconds. No need to validate nanos because parseDuration would
 	// have covered that already.
 	if secs < -maxSecondsInDuration || secs > maxSecondsInDuration {
-		return errors.New("%s: out of range %q", messageDesc.FullName(), input)
+		return errors.New("%s: out of range %q", m.Descriptor().FullName(), input)
 	}
 
-	knownFields := m.KnownFields()
-	knownFields.Set(fieldnum.Duration_Seconds, pref.ValueOf(secs))
-	knownFields.Set(fieldnum.Duration_Nanos, pref.ValueOf(nanos))
+	fds := m.Descriptor().Fields()
+	fdSeconds := fds.ByNumber(fieldnum.Duration_Seconds)
+	fdNanos := fds.ByNumber(fieldnum.Duration_Nanos)
+
+	m.Set(fdSeconds, pref.ValueOf(secs))
+	m.Set(fdNanos, pref.ValueOf(nanos))
 	return nerr.E
 }
 
@@ -834,18 +835,19 @@
 )
 
 func (o MarshalOptions) marshalTimestamp(m pref.Message) error {
-	messageDesc := m.Descriptor()
-	knownFields := m.KnownFields()
+	fds := m.Descriptor().Fields()
+	fdSeconds := fds.ByNumber(fieldnum.Timestamp_Seconds)
+	fdNanos := fds.ByNumber(fieldnum.Timestamp_Nanos)
 
-	secsVal := knownFields.Get(fieldnum.Timestamp_Seconds)
-	nanosVal := knownFields.Get(fieldnum.Timestamp_Nanos)
+	secsVal := m.Get(fdSeconds)
+	nanosVal := m.Get(fdNanos)
 	secs := secsVal.Int()
 	nanos := nanosVal.Int()
 	if secs < minTimestampSeconds || secs > maxTimestampSeconds {
-		return errors.New("%s: seconds out of range %v", messageDesc.FullName(), secs)
+		return errors.New("%s: seconds out of range %v", m.Descriptor().FullName(), secs)
 	}
 	if nanos < 0 || nanos > secondsInNanos {
-		return errors.New("%s: nanos out of range %v", messageDesc.FullName(), nanos)
+		return errors.New("%s: nanos out of range %v", m.Descriptor().FullName(), nanos)
 	}
 	// Uses RFC 3339, where generated output will be Z-normalized and uses 0, 3,
 	// 6 or 9 fractional digits.
@@ -868,22 +870,24 @@
 		return unexpectedJSONError{jval}
 	}
 
-	messageDesc := m.Descriptor()
 	input := jval.String()
 	t, err := time.Parse(time.RFC3339Nano, input)
 	if err != nil {
-		return errors.New("%s: invalid timestamp value %q", messageDesc.FullName(), input)
+		return errors.New("%s: invalid timestamp value %q", m.Descriptor().FullName(), input)
 	}
 	// Validate seconds. No need to validate nanos because time.Parse would have
 	// covered that already.
 	secs := t.Unix()
 	if secs < minTimestampSeconds || secs > maxTimestampSeconds {
-		return errors.New("%s: out of range %q", messageDesc.FullName(), input)
+		return errors.New("%s: out of range %q", m.Descriptor().FullName(), input)
 	}
 
-	knownFields := m.KnownFields()
-	knownFields.Set(fieldnum.Timestamp_Seconds, pref.ValueOf(secs))
-	knownFields.Set(fieldnum.Timestamp_Nanos, pref.ValueOf(int32(t.Nanosecond())))
+	fds := m.Descriptor().Fields()
+	fdSeconds := fds.ByNumber(fieldnum.Timestamp_Seconds)
+	fdNanos := fds.ByNumber(fieldnum.Timestamp_Nanos)
+
+	m.Set(fdSeconds, pref.ValueOf(secs))
+	m.Set(fdNanos, pref.ValueOf(int32(t.Nanosecond())))
 	return nerr.E
 }
 
@@ -893,8 +897,8 @@
 // end up differently after a round-trip.
 
 func (o MarshalOptions) marshalFieldMask(m pref.Message) error {
-	val := m.KnownFields().Get(fieldnum.FieldMask_Paths)
-	list := val.List()
+	fd := m.Descriptor().Fields().ByNumber(fieldnum.FieldMask_Paths)
+	list := m.Get(fd).List()
 	paths := make([]string, 0, list.Len())
 
 	for i := 0; i < list.Len(); i++ {
@@ -926,8 +930,8 @@
 	}
 	paths := strings.Split(str, ",")
 
-	val := m.KnownFields().Get(fieldnum.FieldMask_Paths)
-	list := val.List()
+	fd := m.Descriptor().Fields().ByNumber(fieldnum.FieldMask_Paths)
+	list := m.Mutable(fd).List()
 
 	for _, s := range paths {
 		s = strings.TrimSpace(s)
diff --git a/encoding/prototext/decode.go b/encoding/prototext/decode.go
index 20bdfe6..4444de6 100644
--- a/encoding/prototext/decode.go
+++ b/encoding/prototext/decode.go
@@ -47,12 +47,11 @@
 func (o UnmarshalOptions) Unmarshal(b []byte, m proto.Message) error {
 	var nerr errors.NonFatal
 
-	mr := m.ProtoReflect()
 	// Clear all fields before populating it.
 	// TODO: Determine if this needs to be consistent with protojson and binary unmarshal where
 	// behavior is to merge values into existing message. If decision is to not clear the fields
 	// ahead, code will need to be updated properly when merging nested messages.
-	resetMessage(mr)
+	proto.Reset(m)
 
 	// Parse into text.Value of message type.
 	val, err := text.Unmarshal(b)
@@ -63,7 +62,7 @@
 	if o.Resolver == nil {
 		o.Resolver = protoregistry.GlobalTypes
 	}
-	err = o.unmarshalMessage(val.Message(), mr)
+	err = o.unmarshalMessage(val.Message(), m.ProtoReflect())
 	if !nerr.Merge(err) {
 		return err
 	}
@@ -75,41 +74,19 @@
 	return nerr.E
 }
 
-// resetMessage clears all fields of given protoreflect.Message.
-// TODO: This should go into the proto package.
-func resetMessage(m pref.Message) {
-	knownFields := m.KnownFields()
-	knownFields.Range(func(num pref.FieldNumber, _ pref.Value) bool {
-		knownFields.Clear(num)
-		return true
-	})
-	unknownFields := m.UnknownFields()
-	unknownFields.Range(func(num pref.FieldNumber, _ pref.RawFields) bool {
-		unknownFields.Set(num, nil)
-		return true
-	})
-	extTypes := knownFields.ExtensionTypes()
-	extTypes.Range(func(xt pref.ExtensionType) bool {
-		extTypes.Remove(xt)
-		return true
-	})
-}
-
 // unmarshalMessage unmarshals a [][2]text.Value message into the given protoreflect.Message.
 func (o UnmarshalOptions) unmarshalMessage(tmsg [][2]text.Value, m pref.Message) error {
 	var nerr errors.NonFatal
 
 	messageDesc := m.Descriptor()
-	knownFields := m.KnownFields()
 
 	// Handle expanded Any message.
 	if messageDesc.FullName() == "google.protobuf.Any" && isExpandedAny(tmsg) {
-		return o.unmarshalAny(tmsg[0], knownFields)
+		return o.unmarshalAny(tmsg[0], m)
 	}
 
 	fieldDescs := messageDesc.Fields()
 	reservedNames := messageDesc.ReservedNames()
-	xtTypes := knownFields.ExtensionTypes()
 	var seenNums set.Ints
 	var seenOneofs set.Ints
 
@@ -134,23 +111,14 @@
 			}
 			// Extensions have to be registered first in the message's
 			// ExtensionTypes before setting a value to it.
-			xtName := pref.FullName(tkey.String())
+			extName := pref.FullName(tkey.String())
 			// Check first if it is already registered. This is the case for
 			// repeated fields.
-			xt := xtTypes.ByName(xtName)
-			if xt == nil {
-				var err error
-				xt, err = o.findExtension(xtName)
-				if err != nil && err != protoregistry.NotFound {
-					return errors.New("unable to resolve [%v]: %v", xtName, err)
-				}
-				if xt != nil {
-					xtTypes.Register(xt)
-				}
+			xt, err := o.findExtension(extName)
+			if err != nil && err != protoregistry.NotFound {
+				return errors.New("unable to resolve [%v]: %v", extName, err)
 			}
-			if xt != nil {
-				fd = xt.Descriptor()
-			}
+			fd = xt
 		}
 
 		if fd == nil {
@@ -172,7 +140,7 @@
 				items = tval.List()
 			}
 
-			list := knownFields.Get(fd.Number()).List()
+			list := m.Mutable(fd).List()
 			if err := o.unmarshalList(items, fd, list); !nerr.Merge(err) {
 				return err
 			}
@@ -185,7 +153,7 @@
 				items = tval.List()
 			}
 
-			mmap := knownFields.Get(fd.Number()).Map()
+			mmap := m.Mutable(fd).Map()
 			if err := o.unmarshalMap(items, fd, mmap); !nerr.Merge(err) {
 				return err
 			}
@@ -204,7 +172,7 @@
 			if seenNums.Has(num) {
 				return errors.New("non-repeated field %v is repeated", fd.FullName())
 			}
-			if err := o.unmarshalSingular(tval, fd, knownFields); !nerr.Merge(err) {
+			if err := o.unmarshalSingular(tval, fd, m); !nerr.Merge(err) {
 				return err
 			}
 			seenNums.Set(num)
@@ -230,9 +198,7 @@
 }
 
 // unmarshalSingular unmarshals given text.Value into the non-repeated field.
-func (o UnmarshalOptions) unmarshalSingular(input text.Value, fd pref.FieldDescriptor, knownFields pref.KnownFields) error {
-	num := fd.Number()
-
+func (o UnmarshalOptions) unmarshalSingular(input text.Value, fd pref.FieldDescriptor, m pref.Message) error {
 	var nerr errors.NonFatal
 	var val pref.Value
 	switch fd.Kind() {
@@ -240,11 +206,11 @@
 		if input.Type() != text.Message {
 			return errors.New("%v contains invalid message/group value: %v", fd.FullName(), input)
 		}
-		m := knownFields.NewMessage(num)
-		if err := o.unmarshalMessage(input.Message(), m); !nerr.Merge(err) {
+		m2 := m.NewMessage(fd)
+		if err := o.unmarshalMessage(input.Message(), m2); !nerr.Merge(err) {
 			return err
 		}
-		val = pref.ValueOf(m)
+		val = pref.ValueOf(m2)
 	default:
 		var err error
 		val, err = unmarshalScalar(input, fd)
@@ -252,7 +218,7 @@
 			return err
 		}
 	}
-	knownFields.Set(num, val)
+	m.Set(fd, val)
 
 	return nerr.E
 }
@@ -480,7 +446,7 @@
 
 // unmarshalAny unmarshals an expanded Any textproto. This method assumes that the given
 // tfield has key type of text.String and value type of text.Message.
-func (o UnmarshalOptions) unmarshalAny(tfield [2]text.Value, knownFields pref.KnownFields) error {
+func (o UnmarshalOptions) unmarshalAny(tfield [2]text.Value, m pref.Message) error {
 	var nerr errors.NonFatal
 
 	typeURL := tfield[0].String()
@@ -492,8 +458,8 @@
 	}
 	// Create new message for the embedded message type and unmarshal the
 	// value into it.
-	m := mt.New()
-	if err := o.unmarshalMessage(value, m); !nerr.Merge(err) {
+	m2 := mt.New()
+	if err := o.unmarshalMessage(value, m2); !nerr.Merge(err) {
 		return err
 	}
 	// Serialize the embedded message and assign the resulting bytes to the value field.
@@ -503,13 +469,17 @@
 	b, err := proto.MarshalOptions{
 		AllowPartial:  o.AllowPartial,
 		Deterministic: true,
-	}.Marshal(m.Interface())
+	}.Marshal(m2.Interface())
 	if !nerr.Merge(err) {
 		return err
 	}
 
-	knownFields.Set(fieldnum.Any_TypeUrl, pref.ValueOf(typeURL))
-	knownFields.Set(fieldnum.Any_Value, pref.ValueOf(b))
+	fds := m.Descriptor().Fields()
+	fdType := fds.ByNumber(fieldnum.Any_TypeUrl)
+	fdValue := fds.ByNumber(fieldnum.Any_Value)
+
+	m.Set(fdType, pref.ValueOf(typeURL))
+	m.Set(fdValue, pref.ValueOf(b))
 
 	return nerr.E
 }
diff --git a/encoding/prototext/encode.go b/encoding/prototext/encode.go
index d86492e..c3a91ca 100644
--- a/encoding/prototext/encode.go
+++ b/encoding/prototext/encode.go
@@ -88,13 +88,10 @@
 
 	// Handle known fields.
 	fieldDescs := messageDesc.Fields()
-	knownFields := m.KnownFields()
 	size := fieldDescs.Len()
 	for i := 0; i < size; i++ {
 		fd := fieldDescs.Get(i)
-		num := fd.Number()
-
-		if !knownFields.Has(num) {
+		if !m.Has(fd) {
 			continue
 		}
 
@@ -103,7 +100,7 @@
 		if fd.Kind() == pref.GroupKind {
 			name = text.ValueOf(fd.Message().Name())
 		}
-		pval := knownFields.Get(num)
+		pval := m.Get(fd)
 		var err error
 		msgFields, err = o.appendField(msgFields, name, pval, fd)
 		if !nerr.Merge(err) {
@@ -113,17 +110,14 @@
 
 	// Handle extensions.
 	var err error
-	msgFields, err = o.appendExtensions(msgFields, knownFields)
+	msgFields, err = o.appendExtensions(msgFields, m)
 	if !nerr.Merge(err) {
 		return text.Value{}, err
 	}
 
 	// Handle unknown fields.
 	// TODO: Provide option to exclude or include unknown fields.
-	m.UnknownFields().Range(func(_ pref.FieldNumber, raw pref.RawFields) bool {
-		msgFields = appendUnknown(msgFields, raw)
-		return true
-	})
+	msgFields = appendUnknown(msgFields, m.GetUnknown())
 
 	return text.ValueOf(msgFields), nerr.E
 }
@@ -259,30 +253,29 @@
 }
 
 // appendExtensions marshals extension fields and appends them to the given [][2]text.Value.
-func (o MarshalOptions) appendExtensions(msgFields [][2]text.Value, knownFields pref.KnownFields) ([][2]text.Value, error) {
-	xtTypes := knownFields.ExtensionTypes()
-	xtFields := make([][2]text.Value, 0, xtTypes.Len())
-
+func (o MarshalOptions) appendExtensions(msgFields [][2]text.Value, m pref.Message) ([][2]text.Value, error) {
 	var nerr errors.NonFatal
 	var err error
-	xtTypes.Range(func(xt pref.ExtensionType) bool {
-		name := xt.Descriptor().FullName()
+	var entries [][2]text.Value
+	m.Range(func(fd pref.FieldDescriptor, v pref.Value) bool {
+		if !fd.IsExtension() {
+			return true
+		}
+		xt := fd.(pref.ExtensionType)
+
 		// If extended type is a MessageSet, set field name to be the message type name.
+		name := xt.Descriptor().FullName()
 		if isMessageSetExtension(xt) {
 			name = xt.Descriptor().Message().FullName()
 		}
 
-		num := xt.Descriptor().Number()
-		if knownFields.Has(num) {
-			// Use string type to produce [name] format.
-			tname := text.ValueOf(string(name))
-			pval := knownFields.Get(num)
-			xtFields, err = o.appendField(xtFields, tname, pval, xt.Descriptor())
-			if !nerr.Merge(err) {
-				return false
-			}
-			err = nil
+		// Use string type to produce [name] format.
+		tname := text.ValueOf(string(name))
+		entries, err = o.appendField(entries, tname, v, xt)
+		if !nerr.Merge(err) {
+			return false
 		}
+		err = nil
 		return true
 	})
 	if err != nil {
@@ -290,10 +283,10 @@
 	}
 
 	// Sort extensions lexicographically and append to output.
-	sort.SliceStable(xtFields, func(i, j int) bool {
-		return xtFields[i][0].String() < xtFields[j][0].String()
+	sort.SliceStable(entries, func(i, j int) bool {
+		return entries[i][0].String() < entries[j][0].String()
 	})
-	return append(msgFields, xtFields...), nerr.E
+	return append(msgFields, entries...), nerr.E
 }
 
 // isMessageSetExtension reports whether extension extends a message set.
@@ -347,11 +340,14 @@
 
 // marshalAny converts a google.protobuf.Any protoreflect.Message to a text.Value.
 func (o MarshalOptions) marshalAny(m pref.Message) (text.Value, error) {
-	var nerr errors.NonFatal
-	knownFields := m.KnownFields()
-	typeURL := knownFields.Get(fieldnum.Any_TypeUrl).String()
-	value := knownFields.Get(fieldnum.Any_Value)
+	fds := m.Descriptor().Fields()
+	fdType := fds.ByNumber(fieldnum.Any_TypeUrl)
+	fdValue := fds.ByNumber(fieldnum.Any_Value)
 
+	typeURL := m.Get(fdType).String()
+	value := m.Get(fdValue)
+
+	var nerr errors.NonFatal
 	emt, err := o.Resolver.FindMessageByURL(typeURL)
 	if !nerr.Merge(err) {
 		return text.Value{}, err
diff --git a/encoding/prototext/encode_test.go b/encoding/prototext/encode_test.go
index 98732bc..cece731 100644
--- a/encoding/prototext/encode_test.go
+++ b/encoding/prototext/encode_test.go
@@ -8,15 +8,12 @@
 	"bytes"
 	"encoding/hex"
 	"math"
-	"strings"
 	"testing"
 
 	"github.com/google/go-cmp/cmp"
-	"github.com/google/go-cmp/cmp/cmpopts"
 	"google.golang.org/protobuf/encoding/prototext"
 	"google.golang.org/protobuf/internal/detrand"
 	"google.golang.org/protobuf/internal/encoding/pack"
-	"google.golang.org/protobuf/internal/encoding/wire"
 	pimpl "google.golang.org/protobuf/internal/impl"
 	"google.golang.org/protobuf/internal/scalar"
 	"google.golang.org/protobuf/proto"
@@ -33,11 +30,6 @@
 	detrand.Disable()
 }
 
-// splitLines is a cmpopts.Option for comparing strings with line breaks.
-var splitLines = cmpopts.AcyclicTransformer("SplitLines", func(s string) []string {
-	return strings.Split(s, "\n")
-})
-
 func pb2Enum(i int32) *pb2.Enum {
 	p := new(pb2.Enum)
 	*p = pb2.Enum(i)
@@ -50,15 +42,9 @@
 	return p
 }
 
+// TODO: Use proto.SetExtension when available.
 func setExtension(m proto.Message, xd *protoiface.ExtensionDescV1, val interface{}) {
-	knownFields := m.ProtoReflect().KnownFields()
-	extTypes := knownFields.ExtensionTypes()
-	extTypes.Register(xd.Type)
-	if val == nil {
-		return
-	}
-	pval := xd.Type.ValueOf(val)
-	knownFields.Set(wire.Number(xd.Field), pval)
+	m.ProtoReflect().Set(xd.Type, xd.Type.ValueOf(val))
 }
 
 // dhex decodes a hex-string and returns the bytes and panics if s is invalid.
@@ -938,8 +924,8 @@
 			}.Marshal(),
 		},
 		want: `101: "\x01\x00\x01"
-101: 1
 102: "hello"
+101: 1
 102: "世界"
 `,
 	}, {
@@ -1019,14 +1005,6 @@
 }
 `,
 	}, {
-		desc: "extension message field set to nil",
-		input: func() proto.Message {
-			m := &pb2.Extensions{}
-			setExtension(m, pb2.E_OptExtNested, nil)
-			return m
-		}(),
-		want: "\n",
-	}, {
 		desc: "extensions of repeated fields",
 		input: func() proto.Message {
 			m := &pb2.Extensions{}
@@ -1295,7 +1273,7 @@
 			got := string(b)
 			if tt.want != "" && got != tt.want {
 				t.Errorf("Marshal()\n<got>\n%v\n<want>\n%v\n", got, tt.want)
-				if diff := cmp.Diff(tt.want, got, splitLines); diff != "" {
+				if diff := cmp.Diff(tt.want, got); diff != "" {
 					t.Errorf("Marshal() diff -want +got\n%v\n", diff)
 				}
 			}