aboutsummaryrefslogtreecommitdiff
path: root/unions.go
blob: a9f624c73819455d1c73acf2e3efd4b04b42a347 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
package bare

import (
	"fmt"
	"reflect"
)

// Any type which is a union member must implement this interface. You must
// also call RegisterUnion for go-bare to marshal or unmarshal messages which
// utilize your union type.
type Union interface {
	IsUnion()
}

type UnionTags struct {
	iface reflect.Type
	tags  map[reflect.Type]uint64
	types map[uint64]reflect.Type
}

var (
	// unionInterface is the reflect.Type of the Union interface itself.
	unionInterface = reflect.TypeOf((*Union)(nil)).Elem()

	// unionRegistry maps each interface type passed to RegisterUnion to its
	// member table. It is initialized directly rather than in an init
	// function: package-level variables are initialized before any init
	// function runs, so it is ready for use during package initialization.
	// It is a plain map with no locking.
	unionRegistry = make(map[reflect.Type]*UnionTags)
)

// RegisterUnion registers a union interface type and returns its (initially
// empty) member table. Pass a nil pointer to the interface, then add each
// member type and its tag by chaining calls to Member on the result:
//
//	bare.RegisterUnion((*MyUnion)(nil)).
//		Member(*new(Foo), 0).
//		Member(*new(Bar), 1)
//
// RegisterUnion panics if iface is nil, if the type iface points to is not an
// interface type, if that interface has already been registered, or if it
// does not implement Union. The registry is a plain map with no locking, so
// calls must not run concurrently with each other.
func RegisterUnion(iface interface{}) *UnionTags {
	if iface == nil {
		// reflect.TypeOf(nil) is nil; calling Elem on it would crash with
		// an uninformative nil-pointer dereference.
		panic(fmt.Errorf("RegisterUnion requires a pointer to an interface, got nil"))
	}

	// iface is expected to look like (*MyUnion)(nil); Elem yields MyUnion.
	ity := reflect.TypeOf(iface).Elem()
	if ity.Kind() != reflect.Interface {
		panic(fmt.Errorf("Type %s is not an interface type", ity))
	}
	if _, ok := unionRegistry[ity]; ok {
		panic(fmt.Errorf("Type %s has already been registered", ity))
	}
	if !ity.Implements(reflect.TypeOf((*Union)(nil)).Elem()) {
		panic(fmt.Errorf("Type %s does not implement bare.Union", ity))
	}

	utypes := &UnionTags{
		iface: ity,
		tags:  make(map[reflect.Type]uint64),
		types: make(map[uint64]reflect.Type),
	}
	unionRegistry[ity] = utypes
	return utypes
}

// Member adds the dynamic type of t to the union under the given tag and
// returns ut so that calls can be chained. Only the type of t is used; its
// value is ignored, so a zero value such as *new(T) is sufficient (a typed
// nil pointer registers the pointer type).
//
// Member panics if t is an untyped nil, if its type is not assignable to the
// union interface, or if either the type or the tag is already registered
// for this union.
func (ut *UnionTags) Member(t interface{}, tag uint64) *UnionTags {
	if t == nil {
		// reflect.TypeOf(nil) is nil and AssignableTo would dereference it.
		panic(fmt.Errorf("Cannot register a nil member for union %s", ut.iface))
	}
	ty := reflect.TypeOf(t)
	// %s on a reflect.Type prints its full String() form (e.g. "*pkg.Foo"),
	// which, unlike Name(), is non-empty for pointer and unnamed types.
	if !ty.AssignableTo(ut.iface) {
		panic(fmt.Errorf("Type %s does not implement interface %s",
			ty, ut.iface))
	}
	if _, ok := ut.tags[ty]; ok {
		panic(fmt.Errorf("Type %s is already registered for union %s",
			ty, ut.iface))
	}
	if _, ok := ut.types[tag]; ok {
		panic(fmt.Errorf("Tag %d is already registered for union %s",
			tag, ut.iface))
	}
	ut.tags[ty] = tag
	ut.types[tag] = ty
	return ut
}

// TagFor reports the union tag registered for the dynamic type of v. The
// boolean result is false (and the tag 0) when that type was never added
// via Member.
func (ut *UnionTags) TagFor(v interface{}) (uint64, bool) {
	if found, present := ut.tags[reflect.TypeOf(v)]; present {
		return found, true
	}
	return 0, false
}

// TypeFor looks up the member type registered under tag. The boolean result
// is false (and the type nil) when no member uses that tag.
func (ut *UnionTags) TypeFor(tag uint64) (reflect.Type, bool) {
	if member, present := ut.types[tag]; present {
		return member, true
	}
	return nil, false
}