ac: add gs_{prim,invocation}_id to the abi

Reviewed-by: Nicolai Hähnle <nicolai.haehnle@amd.com>
diff --git a/src/amd/common/ac_nir_to_llvm.c b/src/amd/common/ac_nir_to_llvm.c
index 3d9f613..1ecdeca 100644
--- a/src/amd/common/ac_nir_to_llvm.c
+++ b/src/amd/common/ac_nir_to_llvm.c
@@ -122,7 +122,6 @@
 	LLVMValueRef gs2vs_offset;
 	LLVMValueRef gs_wave_id;
 	LLVMValueRef gs_vtx_offset[6];
-	LLVMValueRef gs_prim_id, gs_invocation_id;
 
 	LLVMValueRef esgs_ring;
 	LLVMValueRef gsvs_ring;
@@ -826,8 +825,8 @@
 
 			add_vgpr_argument(&args, ctx->ac.i32, &ctx->gs_vtx_offset[0]); // vtx01
 			add_vgpr_argument(&args, ctx->ac.i32, &ctx->gs_vtx_offset[2]); // vtx23
-			add_vgpr_argument(&args, ctx->ac.i32, &ctx->gs_prim_id); // prim id
-			add_vgpr_argument(&args, ctx->ac.i32, &ctx->gs_invocation_id);
+			add_vgpr_argument(&args, ctx->ac.i32, &ctx->abi.gs_prim_id); // prim id
+			add_vgpr_argument(&args, ctx->ac.i32, &ctx->abi.gs_invocation_id);
 			add_vgpr_argument(&args, ctx->ac.i32, &ctx->gs_vtx_offset[4]);
 
 			if (previous_stage == MESA_SHADER_VERTEX) {
@@ -852,12 +851,12 @@
 			add_sgpr_argument(&args, ctx->ac.i32, &ctx->gs_wave_id); // wave id
 			add_vgpr_argument(&args, ctx->ac.i32, &ctx->gs_vtx_offset[0]); // vtx0
 			add_vgpr_argument(&args, ctx->ac.i32, &ctx->gs_vtx_offset[1]); // vtx1
-			add_vgpr_argument(&args, ctx->ac.i32, &ctx->gs_prim_id); // prim id
+			add_vgpr_argument(&args, ctx->ac.i32, &ctx->abi.gs_prim_id); // prim id
 			add_vgpr_argument(&args, ctx->ac.i32, &ctx->gs_vtx_offset[2]);
 			add_vgpr_argument(&args, ctx->ac.i32, &ctx->gs_vtx_offset[3]);
 			add_vgpr_argument(&args, ctx->ac.i32, &ctx->gs_vtx_offset[4]);
 			add_vgpr_argument(&args, ctx->ac.i32, &ctx->gs_vtx_offset[5]);
-			add_vgpr_argument(&args, ctx->ac.i32, &ctx->gs_invocation_id);
+			add_vgpr_argument(&args, ctx->ac.i32, &ctx->abi.gs_invocation_id);
 		}
 		break;
 	case MESA_SHADER_FRAGMENT:
@@ -4058,12 +4057,13 @@
 		if (ctx->stage == MESA_SHADER_TESS_CTRL)
 			result = unpack_param(&ctx->ac, ctx->nctx->tcs_rel_ids, 8, 5);
 		else
-			result = ctx->nctx->gs_invocation_id;
+			result = ctx->abi->gs_invocation_id;
 		break;
 	case nir_intrinsic_load_primitive_id:
 		if (ctx->stage == MESA_SHADER_GEOMETRY) {
-			ctx->nctx->shader_info->gs.uses_prim_id = true;
-			result = ctx->nctx->gs_prim_id;
+			if (ctx->nctx)
+				ctx->nctx->shader_info->gs.uses_prim_id = true;
+			result = ctx->abi->gs_prim_id;
 		} else if (ctx->stage == MESA_SHADER_TESS_CTRL) {
 			ctx->nctx->shader_info->tcs.uses_prim_id = true;
 			result = ctx->nctx->tcs_patch_id;
diff --git a/src/amd/common/ac_shader_abi.h b/src/amd/common/ac_shader_abi.h
index 27586d0..56209bd 100644
--- a/src/amd/common/ac_shader_abi.h
+++ b/src/amd/common/ac_shader_abi.h
@@ -42,6 +42,8 @@
 	LLVMValueRef draw_id;
 	LLVMValueRef vertex_id;
 	LLVMValueRef instance_id;
+	LLVMValueRef gs_prim_id;
+	LLVMValueRef gs_invocation_id;
 	LLVMValueRef frag_pos[4];
 	LLVMValueRef front_face;
 	LLVMValueRef ancillary;
diff --git a/src/gallium/drivers/radeonsi/si_shader.c b/src/gallium/drivers/radeonsi/si_shader.c
index 3293dd4..c1a3102 100644
--- a/src/gallium/drivers/radeonsi/si_shader.c
+++ b/src/gallium/drivers/radeonsi/si_shader.c
@@ -759,8 +759,7 @@
 		return LLVMGetParam(ctx->main_fn,
 				    ctx->param_tes_patch_id);
 	case PIPE_SHADER_GEOMETRY:
-		return LLVMGetParam(ctx->main_fn,
-				    ctx->param_gs_prim_id);
+		return ctx->abi.gs_prim_id;
 	default:
 		assert(0);
 		return ctx->i32_0;
@@ -1674,8 +1673,7 @@
 		if (ctx->type == PIPE_SHADER_TESS_CTRL)
 			value = unpack_param(ctx, ctx->param_tcs_rel_ids, 8, 5);
 		else if (ctx->type == PIPE_SHADER_GEOMETRY)
-			value = LLVMGetParam(ctx->main_fn,
-					     ctx->param_gs_instance_id);
+			value = ctx->abi.gs_invocation_id;
 		else
 			assert(!"INVOCATIONID not implemented");
 		break;
@@ -4562,8 +4560,8 @@
 		/* VGPRs (first GS, then VS/TES) */
 		ctx->param_gs_vtx01_offset = add_arg(&fninfo, ARG_VGPR, ctx->i32);
 		ctx->param_gs_vtx23_offset = add_arg(&fninfo, ARG_VGPR, ctx->i32);
-		ctx->param_gs_prim_id = add_arg(&fninfo, ARG_VGPR, ctx->i32);
-		ctx->param_gs_instance_id = add_arg(&fninfo, ARG_VGPR, ctx->i32);
+		add_arg_assign(&fninfo, ARG_VGPR, ctx->i32, &ctx->abi.gs_prim_id);
+		add_arg_assign(&fninfo, ARG_VGPR, ctx->i32, &ctx->abi.gs_invocation_id);
 		ctx->param_gs_vtx45_offset = add_arg(&fninfo, ARG_VGPR, ctx->i32);
 
 		if (ctx->type == PIPE_SHADER_VERTEX) {
@@ -4613,12 +4611,12 @@
 		/* VGPRs */
 		add_arg_assign(&fninfo, ARG_VGPR, ctx->i32, &ctx->gs_vtx_offset[0]);
 		add_arg_assign(&fninfo, ARG_VGPR, ctx->i32, &ctx->gs_vtx_offset[1]);
-		ctx->param_gs_prim_id = add_arg(&fninfo, ARG_VGPR, ctx->i32);
+		add_arg_assign(&fninfo, ARG_VGPR, ctx->i32, &ctx->abi.gs_prim_id);
 		add_arg_assign(&fninfo, ARG_VGPR, ctx->i32, &ctx->gs_vtx_offset[2]);
 		add_arg_assign(&fninfo, ARG_VGPR, ctx->i32, &ctx->gs_vtx_offset[3]);
 		add_arg_assign(&fninfo, ARG_VGPR, ctx->i32, &ctx->gs_vtx_offset[4]);
 		add_arg_assign(&fninfo, ARG_VGPR, ctx->i32, &ctx->gs_vtx_offset[5]);
-		ctx->param_gs_instance_id = add_arg(&fninfo, ARG_VGPR, ctx->i32);
+		add_arg_assign(&fninfo, ARG_VGPR, ctx->i32, &ctx->abi.gs_invocation_id);
 		break;
 
 	case PIPE_SHADER_FRAGMENT:
diff --git a/src/gallium/drivers/radeonsi/si_shader_internal.h b/src/gallium/drivers/radeonsi/si_shader_internal.h
index 7ff8815..ebe11fa 100644
--- a/src/gallium/drivers/radeonsi/si_shader_internal.h
+++ b/src/gallium/drivers/radeonsi/si_shader_internal.h
@@ -183,8 +183,6 @@
 	int param_gs2vs_offset;
 	int param_gs_wave_id; /* GFX6 */
 	LLVMValueRef gs_vtx_offset[6]; /* in dwords (GFX6) */
-	int param_gs_prim_id;
-	int param_gs_instance_id;
 	int param_gs_vtx01_offset; /* in dwords (GFX9) */
 	int param_gs_vtx23_offset; /* in dwords (GFX9) */
 	int param_gs_vtx45_offset; /* in dwords (GFX9) */