Usermode/libc - Fix strchr and strrchr behavior
[tpg/acess2.git] / Usermode / Libraries / ld-acess.so_src / loadlib.c
1 /*
2  AcessOS 1 - Dynamic Loader
3  By thePowersGang
4 */
5 #include "common.h"
6 #include <stdint.h>
7 #include <stdbool.h>
8 #include <acess/sys.h>
9
10 #define DEBUG   0
11
12 #if DEBUG
13 # define DEBUGS(v...)   SysDebug(v)
14 #else
15 # define DEBUGS(v...)   
16 #endif
17
18 #define MAX_QUEUED_ENTRYPOINTS  8
19
20 // === IMPORTS ===
21 extern const tLocalExport       caLocalExports[];
22 extern const int        ciNumLocalExports;
23 extern char     **gEnvP;
24 extern char     gLinkedBase[];
25
26 // === TYPES ===
27 typedef void    tLibEntry(void *, int, char *[], char**);
28
29 // === PROTOTYPES ===
30 void    *IsFileLoaded(const char *file);
31
32 // === GLOBALS ===
33 tLoadedLib      gLoadedLibraries[MAX_LOADED_LIBRARIES];
34 char    gsLoadedStrings[MAX_STRINGS_BYTES];
35 char    *gsNextAvailString = gsLoadedStrings;
// Library entry point recorded by LoadLibrary and invoked later by CallQueuedEntrypoints
struct sQueuedEntry {
        void    *Base;  // Library base address; passed as the first argument to Entry
        tLibEntry       *Entry; // Entry point to call
}       gaQueuedEntrypoints[MAX_QUEUED_ENTRYPOINTS];
// Number of occupied slots at the start of gaQueuedEntrypoints
 int    giNumQueuedEntrypoints;
41 //tLoadLib      *gpLoadedLibraries = NULL;
42
43 // === CODE ===
44 void ldacess_DumpLoadedLibraries(void)
45 {
46         for( int i = 0; i < MAX_LOADED_LIBRARIES; i ++ )
47         {
48                 const tLoadedLib* ll = &gLoadedLibraries[i];
49                 if(ll->Base == 0)       break;  // Last entry has Base set to NULL
50                 _SysDebug("%p: %s", ll->Base, ll->Name);
51         }
52 }
53
54 /**
55  * \brief Call queued up entry points (after relocations completed) 
56  */
57 void CallQueuedEntrypoints(char **EnvP)
58 {
59         while( giNumQueuedEntrypoints )
60         {
61                 giNumQueuedEntrypoints --;
62                 const struct sQueuedEntry       *qe = &gaQueuedEntrypoints[giNumQueuedEntrypoints];
63                 DEBUGS("Calling EP %p for %p", qe->Entry, qe->Base);
64                 qe->Entry(qe->Base, 0, NULL, EnvP);
65         }
66 }
67
68 const char *FindLibrary(char *DestBuf, const char *SoName, const char *ExtraSearchDir)
69 {       
70         // -- #1: Executable Specified
71         if(ExtraSearchDir)
72         {
73                 strcpy(DestBuf, ExtraSearchDir);
74                 strcat(DestBuf, "/");
75                 strcat(DestBuf, SoName);
76                 if(file_exists(DestBuf))        return DestBuf;
77         }
78         
79         // -- #2: System
80         strcpy(DestBuf, SYSTEM_LIB_DIR);
81         strcat(DestBuf, SoName);
82         if(file_exists(DestBuf))        return DestBuf;
83         
84         // -- #3: Current Directory
85         if(file_exists(SoName)) return SoName;
86         
87         return NULL;
88 }
89
90 /**
91  */
92 void *LoadLibrary(const char *SoName, const char *SearchDir, char **envp)
93 {
94         char    sTmpName[1024];
95         void    *base;
96         
97         DEBUGS("LoadLibrary: (SoName='%s', SearchDir='%s', envp=%p)", SoName, SearchDir, envp);
98         
99         // Create Temp Name
100         const char *filename = FindLibrary(sTmpName, SoName, SearchDir);
101         if(filename == NULL) {
102                 DEBUGS("LoadLibrary: RETURN 0");
103                 return 0;
104         }
105         DEBUGS(" LoadLibrary: filename='%s'", filename);
106         
107         if( (base = IsFileLoaded(filename)) )
108                 return base;
109
110         DEBUGS(" LoadLibrary: SysLoadBin()");   
111         // Load Library
112         tLibEntry       *fEntry;
113         base = _SysLoadBin(filename, (void**)&fEntry);
114         if(!base) {
115                 DEBUGS("LoadLibrary: RETURN 0");
116                 return 0;
117         }
118         
119         DEBUGS(" LoadLibrary: iArg=%p, fEntry=%p", base, fEntry);
120         
121         // Load Symbols
122         fEntry = DoRelocate( base, envp, filename );
123         if( !fEntry ) {
124                 return 0;
125         }
126         
127         // Call Entrypoint
128         // - TODO: Queue entrypoint calls
129         if( giNumQueuedEntrypoints >= MAX_QUEUED_ENTRYPOINTS ) {
130                 SysDebug("ERROR - Maximum number of queued entrypoints exceeded on %p '%s'",
131                         base, SoName);
132                 return 0;
133         }
134         gaQueuedEntrypoints[giNumQueuedEntrypoints].Base  = base;
135         gaQueuedEntrypoints[giNumQueuedEntrypoints].Entry = fEntry;
136         giNumQueuedEntrypoints ++;
137         
138         DEBUGS("LoadLibrary: RETURN success");
139         return base;
140 }
141
142 /**
143  * \fn Uint IsFileLoaded(char *file)
144  * \brief Determine if a file is already loaded
145  */
146 void *IsFileLoaded(const char *file)
147 {
148         DEBUGS("IsFileLoaded: (file='%s')", file);
149
150         // Applications link against either libld-acess.so or ld-acess.so
151         if( strcmp(file, "/Acess/Libs/libld-acess.so") == 0
152          || strcmp(file, "/Acess/Libs/ld-acess.so") == 0 )
153         {
154                 DEBUGS("IsFileLoaded: Found local (%p)", &gLinkedBase);
155                 return &gLinkedBase;
156         }
157
158         for( int i = 0; i < MAX_LOADED_LIBRARIES; i++ )
159         {
160                 if(gLoadedLibraries[i].Base == 0)       break;  // Last entry has Base set to NULL
161                 DEBUGS(" strcmp('%s', '%s')", gLoadedLibraries[i].Name, file);
162                 if(strcmp(gLoadedLibraries[i].Name, file) == 0) {
163                         DEBUGS("IsFileLoaded: Found %i (%p)", i, gLoadedLibraries[i].Base);
164                         return gLoadedLibraries[i].Base;
165                 }
166         }
167         DEBUGS("IsFileLoaded: Not Found");
168         return 0;
169 }
170
171 /**
172  * \fn void AddLoaded(char *File, Uint base)
173  * \brief Add a file to the loaded list
174  */
175 void AddLoaded(const char *File, void *base)
176 {
177          int    i, length;
178         char    *name = gsNextAvailString;
179         
180         DEBUGS("AddLoaded: (File='%s', base=%p)", File, base);
181         
182         // Find a free slot
183         for( i = 0; i < MAX_LOADED_LIBRARIES; i ++ )
184         {
185                 if(gLoadedLibraries[i].Base == 0)       break;
186         }
187         if(i == MAX_LOADED_LIBRARIES) {
188                 SysDebug("ERROR - ld-acess.so has run out of load slots!");
189                 return;
190         }
191         
192         // Check space in string buffer
193         length = strlen(File);
194         if(&name[length+1] >= &gsLoadedStrings[MAX_STRINGS_BYTES]) {
195                 SysDebug("ERROR - ld-acess.so has run out of string buffer memory!");
196                 return;
197         }
198         
199         // Set information
200         gLoadedLibraries[i].Base = base;
201         strcpy(name, File);
202         gLoadedLibraries[i].Name = name;
203         gsNextAvailString = &name[length+1];
204         DEBUGS("'%s' (%p) loaded as %i", name, base, i);
205         return;
206 }
207
208 /**
209  * \fn void Unload(Uint Base)
210  */
211 void Unload(void *Base)
212 {       
213          int    id;
214         char    *str;
215         for( id = 0; id < MAX_LOADED_LIBRARIES; id++ )
216         {
217                 if(gLoadedLibraries[id].Base == Base)   break;
218         }
219         if(id == MAX_LOADED_LIBRARIES)  return;
220         
221         // Unload Binary
222         _SysUnloadBin( Base );
223         // Save String Pointer
224         str = gLoadedLibraries[id].Name;
225         
226         // Compact Loaded List
227         int j = id;
228         for( int i = j + 1; i < MAX_LOADED_LIBRARIES; i++, j++ )
229         {
230                 if(gLoadedLibraries[i].Base == 0)       break;
231                 // Compact String
232                 strcpy(str, gLoadedLibraries[i].Name);
233                 str += strlen(str)+1;
234                 // Compact Entry
235                 gLoadedLibraries[j].Base = gLoadedLibraries[i].Base;
236                 gLoadedLibraries[j].Name = str;
237         }
238         
239         // NULL Last Entry
240         gLoadedLibraries[j].Base = 0;
241         gLoadedLibraries[j].Name = NULL;
242         // Save next string
243         gsNextAvailString = str;
244 }
245
246 /**
247  \fn Uint GetSymbol(const char *name)
248  \brief Gets a symbol value from a loaded library
249 */
250 int GetSymbol(const char *name, void **Value, size_t *Size, void *IgnoreBase)
251 {
252         ASSERT(name);
253         ASSERT(Value);
254         ASSERT(Size);
255         //SysDebug("GetSymbol: (%s)");
256         for( int i = 0; i < ciNumLocalExports; i ++ )
257         {
258                 const tLocalExport* le = &caLocalExports[i];
259                 if( strcmp(le->Name, name) == 0 ) {
260                         *Value = le->Value;
261                         if(Size)
262                                 *Size = 0;
263                         DEBUGS("'%s' = Local %p+%#x", name, le->Value, 0);
264                         return 1;
265                 }
266         }
267
268         bool have_weak = false; 
269         for(int i = 0; i < MAX_LOADED_LIBRARIES && gLoadedLibraries[i].Base != 0; i ++)
270         {
271                 const tLoadedLib* ll = &gLoadedLibraries[i];
272                 // Allow ignoring the current module
273                 if( ll->Base == IgnoreBase ) {
274                         //SysDebug("GetSymbol: Ignore %p", gLoadedLibraries[i].Base);
275                         continue ;
276                 }
277                 
278                 //SysDebug(" GetSymbol: Trying 0x%x, '%s'", ll->Base, ll->Name);
279                 void    *tmpval;
280                 size_t  tmpsize;
281                 int rv = GetSymbolFromBase(ll->Base, name, &tmpval, &tmpsize);
282                 if(rv)
283                 {
284                         *Value = tmpval;
285                         *Size = tmpsize;
286                         if( rv == 1 ) {
287                                 DEBUGS("'%s' = %p '%s' Strong %p+%#x", name, ll->Base, ll->Name, *Value, *Size);
288                                 return 1;
289                         }
290                         have_weak = true;
291                 }
292         }
293         if(have_weak) {
294                 DEBUGS("'%s' = Weak %p+%#x", name, *Value, *Size);
295                 return 2;
296         }
297         else {
298                 DEBUGS("'%s' = ?", name);
299                 return 0;
300         }
301 }
302
/**
 * \brief Get a symbol from one specific library
 * \param base  Base address of the library image
 * \param name  Symbol name
 * \param ret   Receives the symbol's address
 * \param Size  Receives the symbol's size
 * \return Result of the format-specific lookup, or 0 if the image format is not recognised
 *
 * The format is chosen from the magic bytes at the start of the image.
 */
int GetSymbolFromBase(void *base, const char *name, void **ret, size_t *Size)
{
        const uint8_t   *magic = base;

        // "\x7F" "ELF"
        const bool      is_elf = magic[0] == 0x7F && magic[1] == 'E'
                && magic[2] == 'L' && magic[3] == 'F';
        if( is_elf ) {
                return ElfGetSymbol(base, name, ret, Size);
        }

        // "MZ"
        const bool      is_pe = magic[0] == 'M' && magic[1] == 'Z';
        if( is_pe ) {
                return PE_GetSymbol(base, name, ret, Size);
        }

        SysDebug("Unknown type at %p (%02x %02x %02x %02x)", base,
                magic[0], magic[1], magic[2], magic[3]);
        return 0;
}
318

UCC git Repository :: git.ucc.asn.au