Merge branch 'master' of github.com:thepowersgang/acess2
[tpg/acess2.git] / Usermode / Libraries / ld-acess.so_src / loadlib.c
1 /*
2  AcessOS 1 - Dynamic Loader
3  By thePowersGang
4 */
5 #include "common.h"
6 #include <stdint.h>
7 #include <acess/sys.h>
8
#define DEBUG	0

#if !DEBUG
// Tracing compiled out: the call and its arguments vanish entirely
# define DEBUGS(...)
#else
// Tracing enabled: forward everything to SysDebug()
# define DEBUGS(...)	SysDebug(__VA_ARGS__)
#endif
16
17 // === PROTOTYPES ===
18 void    *IsFileLoaded(const char *file);
19
20 // === IMPORTS ===
21 extern const struct {
22         void    *Value;
23         char    *Name;
24 }       caLocalExports[];
25 extern const int        ciNumLocalExports;
26 extern char     **gEnvP;
27 extern char     gLinkedBase[];
28
29 // === GLOABLS ===
30 tLoadedLib      gLoadedLibraries[MAX_LOADED_LIBRARIES];
31 char    gsLoadedStrings[MAX_STRINGS_BYTES];
32 char    *gsNextAvailString = gsLoadedStrings;
33 //tLoadLib      *gpLoadedLibraries = NULL;
34
35 // === CODE ===
36 const char *FindLibrary(char *DestBuf, const char *SoName, const char *ExtraSearchDir)
37 {       
38         // -- #1: Executable Specified
39         if(ExtraSearchDir)
40         {
41                 strcpy(DestBuf, ExtraSearchDir);
42                 strcat(DestBuf, "/");
43                 strcat(DestBuf, SoName);
44                 if(file_exists(DestBuf))        return DestBuf;
45         }
46         
47         // -- #2: System
48         strcpy(DestBuf, SYSTEM_LIB_DIR);
49         strcat(DestBuf, SoName);
50         if(file_exists(DestBuf))        return DestBuf;
51         
52         // -- #3: Current Directory
53         if(file_exists(SoName)) return SoName;
54         
55         return NULL;
56 }
57
58 /**
59  */
60 void *LoadLibrary(const char *SoName, const char *SearchDir, char **envp)
61 {
62         char    sTmpName[1024];
63         const char      *filename;
64         void    *base;
65         void    (*fEntry)(void *, int, char *[], char**);
66         
67         DEBUGS("LoadLibrary: (SoName='%s', SearchDir='%s', envp=%p)", SoName, SearchDir, envp);
68         
69         // Create Temp Name
70         filename = FindLibrary(sTmpName, SoName, SearchDir);
71         if(filename == NULL) {
72                 DEBUGS("LoadLibrary: RETURN 0");
73                 return 0;
74         }
75         DEBUGS(" LoadLibrary: filename='%s'", filename);
76         
77         if( (base = IsFileLoaded(filename)) )
78                 return base;
79
80         DEBUGS(" LoadLibrary: SysLoadBin()");   
81         // Load Library
82         base = _SysLoadBin(filename, (void**)&fEntry);
83         if(!base) {
84                 DEBUGS("LoadLibrary: RETURN 0");
85                 return 0;
86         }
87         
88         DEBUGS(" LoadLibrary: iArg=%p, fEntry=%p", base, fEntry);
89         
90         // Load Symbols
91         fEntry = DoRelocate( base, envp, filename );
92         if( !fEntry ) {
93                 return 0;
94         }
95         
96         // Call Entrypoint
97         DEBUGS(" LoadLibrary: '%s' Entry %p", SoName, fEntry);
98         fEntry(base, 0, NULL, gEnvP);
99         
100         DEBUGS("LoadLibrary: RETURN 1");
101         return base;
102 }
103
104 /**
105  * \fn Uint IsFileLoaded(char *file)
106  * \brief Determine if a file is already loaded
107  */
108 void *IsFileLoaded(const char *file)
109 {
110          int    i;
111         DEBUGS("IsFileLoaded: (file='%s')", file);
112
113         // Applications link against either libld-acess.so or ld-acess.so
114         if( strcmp(file, "/Acess/Libs/libld-acess.so") == 0
115          || strcmp(file, "/Acess/Libs/ld-acess.so") == 0 )
116         {
117                 DEBUGS("IsFileLoaded: Found local (%p)", &gLinkedBase);
118                 return &gLinkedBase;
119         }
120
121         for( i = 0; i < MAX_LOADED_LIBRARIES; i++ )
122         {
123                 if(gLoadedLibraries[i].Base == 0)       break;  // Last entry has Base set to NULL
124                 DEBUGS(" strcmp('%s', '%s')", gLoadedLibraries[i].Name, file);
125                 if(strcmp(gLoadedLibraries[i].Name, file) == 0) {
126                         DEBUGS("IsFileLoaded: Found %i (%p)", i, gLoadedLibraries[i].Base);
127                         return gLoadedLibraries[i].Base;
128                 }
129         }
130         DEBUGS("IsFileLoaded: Not Found");
131         return 0;
132 }
133
134 /**
135  * \fn void AddLoaded(char *File, Uint base)
136  * \brief Add a file to the loaded list
137  */
138 void AddLoaded(const char *File, void *base)
139 {
140          int    i, length;
141         char    *name = gsNextAvailString;
142         
143         DEBUGS("AddLoaded: (File='%s', base=%p)", File, base);
144         
145         // Find a free slot
146         for( i = 0; i < MAX_LOADED_LIBRARIES; i ++ )
147         {
148                 if(gLoadedLibraries[i].Base == 0)       break;
149         }
150         if(i == MAX_LOADED_LIBRARIES) {
151                 SysDebug("ERROR - ld-acess.so has run out of load slots!");
152                 return;
153         }
154         
155         // Check space in string buffer
156         length = strlen(File);
157         if(&name[length+1] >= &gsLoadedStrings[MAX_STRINGS_BYTES]) {
158                 SysDebug("ERROR - ld-acess.so has run out of string buffer memory!");
159                 return;
160         }
161         
162         // Set information
163         gLoadedLibraries[i].Base = base;
164         strcpy(name, File);
165         gLoadedLibraries[i].Name = name;
166         gsNextAvailString = &name[length+1];
167         DEBUGS("'%s' (%p) loaded as %i", name, base, i);
168         return;
169 }
170
171 /**
172  * \fn void Unload(Uint Base)
173  */
174 void Unload(void *Base)
175 {       
176          int    i, j;
177          int    id;
178         char    *str;
179         for( id = 0; id < MAX_LOADED_LIBRARIES; id++ )
180         {
181                 if(gLoadedLibraries[id].Base == Base)   break;
182         }
183         if(id == MAX_LOADED_LIBRARIES)  return;
184         
185         // Unload Binary
186         _SysUnloadBin( Base );
187         // Save String Pointer
188         str = gLoadedLibraries[id].Name;
189         
190         // Compact Loaded List
191         j = id;
192         for( i = j + 1; i < MAX_LOADED_LIBRARIES; i++, j++ )
193         {
194                 if(gLoadedLibraries[i].Base == 0)       break;
195                 // Compact String
196                 strcpy(str, gLoadedLibraries[i].Name);
197                 str += strlen(str)+1;
198                 // Compact Entry
199                 gLoadedLibraries[j].Base = gLoadedLibraries[i].Base;
200                 gLoadedLibraries[j].Name = str;
201         }
202         
203         // NULL Last Entry
204         gLoadedLibraries[j].Base = 0;
205         gLoadedLibraries[j].Name = NULL;
206         // Save next string
207         gsNextAvailString = str;
208 }
209
210 /**
211  \fn Uint GetSymbol(const char *name)
212  \brief Gets a symbol value from a loaded library
213 */
214 int GetSymbol(const char *name, void **Value, size_t *Size)
215 {
216          int    i;
217         
218         //SysDebug("ciNumLocalExports = %i", ciNumLocalExports);
219         for(i=0;i<ciNumLocalExports;i++)
220         {
221                 if( strcmp(caLocalExports[i].Name, name) == 0 ) {
222                         *Value = caLocalExports[i].Value;
223                         if(Size)
224                                 *Size = 0;
225                         return 1;
226                 }
227         }
228         
229         // Entry 0 is ld-acess, ignore it
230         for(i = 0; i < MAX_LOADED_LIBRARIES; i ++)
231         {
232                 if(gLoadedLibraries[i].Base == 0)
233                         break;
234                 
235                 //SysDebug(" GetSymbol: Trying 0x%x, '%s'",
236                 //      gLoadedLibraries[i].Base, gLoadedLibraries[i].Name);
237                 if(GetSymbolFromBase(gLoadedLibraries[i].Base, name, Value, Size))
238                         return 1;
239         }
240         SysDebug("GetSymbol: === Symbol '%s' not found ===", name);
241         return 0;
242 }
243
/**
 * \brief Look up a symbol in one specific loaded image
 * \param base	Load address of the image (its file header is read from here)
 * \param name	Symbol to find
 * \param ret	Receives the symbol's address (filled in by the format handler)
 * \param Size	Passed through to the format handler
 * \return Result of ElfGetSymbol()/PE_GetSymbol(), or 0 for an unrecognised header
 *
 * The format is chosen from the magic bytes: 0x7F 'E' 'L' 'F' selects the ELF
 * handler, and 'M' 'Z' selects the PE handler.
 */
int GetSymbolFromBase(void *base, const char *name, void **ret, size_t *Size)
{
	const uint8_t	*magic = base;
	const int	isElf = magic[0] == 0x7F && magic[1] == 'E' && magic[2] == 'L' && magic[3] == 'F';

	if( isElf )
		return ElfGetSymbol(base, name, ret, Size);
	if( magic[0] == 'M' && magic[1] == 'Z' )
		return PE_GetSymbol(base, name, ret, Size);

	SysDebug("Unknown type at %p (%02x %02x %02x %02x)", base,
		magic[0], magic[1], magic[2], magic[3]);
	return 0;
}
259

UCC git Repository :: git.ucc.asn.au