Usermode/ld-acess - Fix dynamic linking quirk (STB_WEAK and R_COPY)
[tpg/acess2.git] / Usermode / Libraries / ld-acess.so_src / loadlib.c
1 /*
2  AcessOS 1 - Dynamic Loader
3  By thePowersGang
4 */
5 #include "common.h"
6 #include <stdint.h>
7 #include <stdbool.h>
8 #include <acess/sys.h>
9
// Set to non-zero to enable DEBUGS() tracing in this file
#define DEBUG	0

// Debug trace helper: forwards to SysDebug() when DEBUG is non-zero,
// otherwise expands to nothing (its arguments are not evaluated).
#if DEBUG
# define DEBUGS(v...)	SysDebug(v)
#else
# define DEBUGS(v...)	
#endif

// Capacity of gaQueuedEntrypoints; LoadLibrary() fails once it is full
#define MAX_QUEUED_ENTRYPOINTS	8
19
// === IMPORTS ===
// Symbols exported by ld-acess itself; GetSymbol() searches these first
extern const struct {
	void	*Value;
	char	*Name;
}	caLocalExports[];
extern const int	ciNumLocalExports;	// Number of entries in caLocalExports
extern char	**gEnvP;
// Returned by IsFileLoaded() for libld-acess.so / ld-acess.so, i.e. it stands
// for this loader's own image. NOTE(review): presumably defined by the linker
// script as the loader's link address — confirm.
extern char	gLinkedBase[];

// === TYPES ===
// Library entrypoint; CallQueuedEntrypoints() invokes it as Entry(Base, 0, NULL, EnvP)
typedef void	tLibEntry(void *, int, char *[], char**);

// === PROTOTYPES ===
void	*IsFileLoaded(const char *file);

// === GLOBALS ===
// Table of loaded libraries. Kept compact: the first entry with Base == 0
// terminates the list (AddLoaded() appends, Unload() shifts entries down).
tLoadedLib	gLoadedLibraries[MAX_LOADED_LIBRARIES];
// Backing store for gLoadedLibraries[].Name, appended to in table order
char	gsLoadedStrings[MAX_STRINGS_BYTES];
char	*gsNextAvailString = gsLoadedStrings;	// Next free byte in gsLoadedStrings
// Entrypoints of newly loaded libraries, run by CallQueuedEntrypoints()
// once relocation has finished
struct sQueuedEntry {
	void	*Base;
	tLibEntry	*Entry;
}	gaQueuedEntrypoints[MAX_QUEUED_ENTRYPOINTS];
 int	giNumQueuedEntrypoints;	// Number of valid entries in gaQueuedEntrypoints
45
46 // === CODE ===
47 void ldacess_DumpLoadedLibraries(void)
48 {
49         for( int i = 0; i < MAX_LOADED_LIBRARIES; i ++ )
50         {
51                 if(gLoadedLibraries[i].Base == 0)       break;  // Last entry has Base set to NULL
52                 _SysDebug("%p: %s",
53                         gLoadedLibraries[i].Base,
54                         gLoadedLibraries[i].Name
55                         );
56         }
57 }
58
59 /**
60  * \brief Call queued up entry points (after relocations completed) 
61  */
62 void CallQueuedEntrypoints(char **EnvP)
63 {
64         while( giNumQueuedEntrypoints )
65         {
66                 giNumQueuedEntrypoints --;
67                 const struct sQueuedEntry       *qe = &gaQueuedEntrypoints[giNumQueuedEntrypoints];
68                 //_SysDebug("Calling EP for %p", qe->Base);
69                 qe->Entry(qe->Base, 0, NULL, EnvP);
70         }
71 }
72
73 const char *FindLibrary(char *DestBuf, const char *SoName, const char *ExtraSearchDir)
74 {       
75         // -- #1: Executable Specified
76         if(ExtraSearchDir)
77         {
78                 strcpy(DestBuf, ExtraSearchDir);
79                 strcat(DestBuf, "/");
80                 strcat(DestBuf, SoName);
81                 if(file_exists(DestBuf))        return DestBuf;
82         }
83         
84         // -- #2: System
85         strcpy(DestBuf, SYSTEM_LIB_DIR);
86         strcat(DestBuf, SoName);
87         if(file_exists(DestBuf))        return DestBuf;
88         
89         // -- #3: Current Directory
90         if(file_exists(SoName)) return SoName;
91         
92         return NULL;
93 }
94
95 /**
96  */
97 void *LoadLibrary(const char *SoName, const char *SearchDir, char **envp)
98 {
99         char    sTmpName[1024];
100         void    *base;
101         
102         DEBUGS("LoadLibrary: (SoName='%s', SearchDir='%s', envp=%p)", SoName, SearchDir, envp);
103         
104         // Create Temp Name
105         const char *filename = FindLibrary(sTmpName, SoName, SearchDir);
106         if(filename == NULL) {
107                 DEBUGS("LoadLibrary: RETURN 0");
108                 return 0;
109         }
110         DEBUGS(" LoadLibrary: filename='%s'", filename);
111         
112         if( (base = IsFileLoaded(filename)) )
113                 return base;
114
115         DEBUGS(" LoadLibrary: SysLoadBin()");   
116         // Load Library
117         tLibEntry       *fEntry;
118         base = _SysLoadBin(filename, (void**)&fEntry);
119         if(!base) {
120                 DEBUGS("LoadLibrary: RETURN 0");
121                 return 0;
122         }
123         
124         DEBUGS(" LoadLibrary: iArg=%p, fEntry=%p", base, fEntry);
125         
126         // Load Symbols
127         fEntry = DoRelocate( base, envp, filename );
128         if( !fEntry ) {
129                 return 0;
130         }
131         
132         // Call Entrypoint
133         // - TODO: Queue entrypoint calls
134         if( giNumQueuedEntrypoints >= MAX_QUEUED_ENTRYPOINTS ) {
135                 SysDebug("ERROR - Maximum number of queued entrypoints exceeded on %p '%s'",
136                         base, SoName);
137                 return 0;
138         }
139         gaQueuedEntrypoints[giNumQueuedEntrypoints].Base  = base;
140         gaQueuedEntrypoints[giNumQueuedEntrypoints].Entry = fEntry;
141         giNumQueuedEntrypoints ++;
142         
143         DEBUGS("LoadLibrary: RETURN success");
144         return base;
145 }
146
147 /**
148  * \fn Uint IsFileLoaded(char *file)
149  * \brief Determine if a file is already loaded
150  */
151 void *IsFileLoaded(const char *file)
152 {
153          int    i;
154         DEBUGS("IsFileLoaded: (file='%s')", file);
155
156         // Applications link against either libld-acess.so or ld-acess.so
157         if( strcmp(file, "/Acess/Libs/libld-acess.so") == 0
158          || strcmp(file, "/Acess/Libs/ld-acess.so") == 0 )
159         {
160                 DEBUGS("IsFileLoaded: Found local (%p)", &gLinkedBase);
161                 return &gLinkedBase;
162         }
163
164         for( i = 0; i < MAX_LOADED_LIBRARIES; i++ )
165         {
166                 if(gLoadedLibraries[i].Base == 0)       break;  // Last entry has Base set to NULL
167                 DEBUGS(" strcmp('%s', '%s')", gLoadedLibraries[i].Name, file);
168                 if(strcmp(gLoadedLibraries[i].Name, file) == 0) {
169                         DEBUGS("IsFileLoaded: Found %i (%p)", i, gLoadedLibraries[i].Base);
170                         return gLoadedLibraries[i].Base;
171                 }
172         }
173         DEBUGS("IsFileLoaded: Not Found");
174         return 0;
175 }
176
177 /**
178  * \fn void AddLoaded(char *File, Uint base)
179  * \brief Add a file to the loaded list
180  */
181 void AddLoaded(const char *File, void *base)
182 {
183          int    i, length;
184         char    *name = gsNextAvailString;
185         
186         DEBUGS("AddLoaded: (File='%s', base=%p)", File, base);
187         
188         // Find a free slot
189         for( i = 0; i < MAX_LOADED_LIBRARIES; i ++ )
190         {
191                 if(gLoadedLibraries[i].Base == 0)       break;
192         }
193         if(i == MAX_LOADED_LIBRARIES) {
194                 SysDebug("ERROR - ld-acess.so has run out of load slots!");
195                 return;
196         }
197         
198         // Check space in string buffer
199         length = strlen(File);
200         if(&name[length+1] >= &gsLoadedStrings[MAX_STRINGS_BYTES]) {
201                 SysDebug("ERROR - ld-acess.so has run out of string buffer memory!");
202                 return;
203         }
204         
205         // Set information
206         gLoadedLibraries[i].Base = base;
207         strcpy(name, File);
208         gLoadedLibraries[i].Name = name;
209         gsNextAvailString = &name[length+1];
210         DEBUGS("'%s' (%p) loaded as %i", name, base, i);
211         return;
212 }
213
214 /**
215  * \fn void Unload(Uint Base)
216  */
217 void Unload(void *Base)
218 {       
219          int    i, j;
220          int    id;
221         char    *str;
222         for( id = 0; id < MAX_LOADED_LIBRARIES; id++ )
223         {
224                 if(gLoadedLibraries[id].Base == Base)   break;
225         }
226         if(id == MAX_LOADED_LIBRARIES)  return;
227         
228         // Unload Binary
229         _SysUnloadBin( Base );
230         // Save String Pointer
231         str = gLoadedLibraries[id].Name;
232         
233         // Compact Loaded List
234         j = id;
235         for( i = j + 1; i < MAX_LOADED_LIBRARIES; i++, j++ )
236         {
237                 if(gLoadedLibraries[i].Base == 0)       break;
238                 // Compact String
239                 strcpy(str, gLoadedLibraries[i].Name);
240                 str += strlen(str)+1;
241                 // Compact Entry
242                 gLoadedLibraries[j].Base = gLoadedLibraries[i].Base;
243                 gLoadedLibraries[j].Name = str;
244         }
245         
246         // NULL Last Entry
247         gLoadedLibraries[j].Base = 0;
248         gLoadedLibraries[j].Name = NULL;
249         // Save next string
250         gsNextAvailString = str;
251 }
252
253 /**
254  \fn Uint GetSymbol(const char *name)
255  \brief Gets a symbol value from a loaded library
256 */
257 int GetSymbol(const char *name, void **Value, size_t *Size, void *IgnoreBase)
258 {
259         //SysDebug("GetSymbol: (%s)");
260         for( int i = 0; i < ciNumLocalExports; i ++ )
261         {
262                 if( strcmp(caLocalExports[i].Name, name) == 0 ) {
263                         *Value = caLocalExports[i].Value;
264                         if(Size)
265                                 *Size = 0;
266                         //SysDebug("GetSymbol: Local %p+0x%x", *Value, 0);
267                         return 1;
268                 }
269         }
270
271         bool have_weak = false; 
272         for(int i = 0; i < MAX_LOADED_LIBRARIES && gLoadedLibraries[i].Base != 0; i ++)
273         {
274                 // Allow ignoring the current module
275                 if( gLoadedLibraries[i].Base == IgnoreBase ) {
276                         //SysDebug("GetSymbol: Ignore %p", gLoadedLibraries[i].Base);
277                         continue ;
278                 }
279                 
280                 //SysDebug(" GetSymbol: Trying 0x%x, '%s'",
281                 //      gLoadedLibraries[i].Base, gLoadedLibraries[i].Name);
282                 void    *tmpval;
283                 size_t  tmpsize;
284                 int rv = GetSymbolFromBase(gLoadedLibraries[i].Base, name, &tmpval, &tmpsize);
285                 if(rv)
286                 {
287                         *Value = tmpval;
288                         *Size = tmpsize;
289                         if( rv == 1 ) {
290                                 return 1;
291                         }
292                         have_weak = true;
293                 }
294         }
295         if(have_weak) {
296                 return 2;
297         }
298         else {
299                 return 0;
300         }
301 }
302
/**
 * \fn int GetSymbolFromBase(void *base, const char *name, void **ret, size_t *Size)
 * \brief Gets a symbol from a specified library
 * \param base	Address of the loaded library image
 * \param name	Symbol name
 * \param ret	Receives the symbol's address
 * \param Size	Receives the symbol's size
 * \return Result of ElfGetSymbol()/PE_GetSymbol(), or 0 if base is NULL or the
 *         image format is not recognised
 *
 * The format is chosen from the header magic: 0x7F 'E' 'L' 'F' selects ELF,
 * 'M' 'Z' selects PE.
 */
int GetSymbolFromBase(void *base, const char *name, void **ret, size_t *Size)
{
	// There is no header to inspect at a NULL base
	if( !base )
		return 0;

	const uint8_t	*hdr = base;
	if(hdr[0] == 0x7F && hdr[1] == 'E' && hdr[2] == 'L' && hdr[3] == 'F')
		return ElfGetSymbol(base, name, ret, Size);
	if(hdr[0] == 'M' && hdr[1] == 'Z')
		return PE_GetSymbol(base, name, ret, Size);
	SysDebug("Unknown type at %p (%02x %02x %02x %02x)", base,
		hdr[0], hdr[1], hdr[2], hdr[3]);
	return 0;
}
318

UCC git Repository :: git.ucc.asn.au